mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-03 11:57:07 +02:00
Move crates under a sub folder to clean up the code
This commit is contained in:
parent
30f3c30389
commit
9c1e54a2c8
1062 changed files with 19 additions and 20 deletions
435
crates/milli/src/vector/error.rs
Normal file
435
crates/milli/src/vector/error.rs
Normal file
|
@ -0,0 +1,435 @@
|
|||
use std::collections::BTreeMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use hf_hub::api::sync::ApiError;
|
||||
|
||||
use super::parsed_vectors::ParsedVectorsDiff;
|
||||
use super::rest::ConfigurationSource;
|
||||
use crate::error::FaultSource;
|
||||
use crate::{FieldDistribution, PanicCatched};
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Error while generating embeddings: {inner}")]
|
||||
pub struct Error {
|
||||
pub inner: Box<ErrorKind>,
|
||||
}
|
||||
|
||||
impl<I: Into<ErrorKind>> From<I> for Error {
|
||||
fn from(value: I) -> Self {
|
||||
Self { inner: Box::new(value.into()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn fault(&self) -> FaultSource {
|
||||
match &*self.inner {
|
||||
ErrorKind::NewEmbedderError(inner) => inner.fault,
|
||||
ErrorKind::EmbedError(inner) => inner.fault,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ErrorKind {
|
||||
#[error(transparent)]
|
||||
NewEmbedderError(#[from] NewEmbedderError),
|
||||
#[error(transparent)]
|
||||
EmbedError(#[from] EmbedError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct EmbedError {
|
||||
pub kind: EmbedErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EmbedErrorKind {
|
||||
#[error("could not tokenize:\n - {0}")]
|
||||
Tokenize(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("unexpected tensor shape:\n - {0}")]
|
||||
TensorShape(candle_core::Error),
|
||||
#[error("unexpected tensor value:\n - {0}")]
|
||||
TensorValue(candle_core::Error),
|
||||
#[error("could not run model:\n - {0}")]
|
||||
ModelForward(candle_core::Error),
|
||||
#[error("attempt to embed the following text in a configuration where embeddings must be user provided:\n - `{0}`")]
|
||||
ManualEmbed(String),
|
||||
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually{}", option_info(.0.as_deref(), "server replied with "))]
|
||||
OllamaModelNotFoundError(Option<String>),
|
||||
#[error("error deserialization the response body as JSON:\n - {0}")]
|
||||
RestResponseDeserialization(std::io::Error),
|
||||
#[error("expected a response containing {0} embeddings, got only {1}")]
|
||||
RestResponseEmbeddingCount(usize, usize),
|
||||
#[error("could not authenticate against {embedding} server{server_reply}{hint}", embedding=match *.1 {
|
||||
ConfigurationSource::User => "embedding",
|
||||
ConfigurationSource::OpenAi => "OpenAI",
|
||||
ConfigurationSource::Ollama => "ollama"
|
||||
},
|
||||
server_reply=option_info(.0.as_deref(), "server replied with "),
|
||||
hint=match *.1 {
|
||||
ConfigurationSource::User => "\n - Hint: Check the `apiKey` parameter in the embedder configuration",
|
||||
ConfigurationSource::OpenAi => "\n - Hint: Check the `apiKey` parameter in the embedder configuration, and the `MEILI_OPENAI_API_KEY` and `OPENAI_API_KEY` environment variables",
|
||||
ConfigurationSource::Ollama => "\n - Hint: Check the `apiKey` parameter in the embedder configuration"
|
||||
})]
|
||||
RestUnauthorized(Option<String>, ConfigurationSource),
|
||||
#[error("sent too many requests to embedding server{}", option_info(.0.as_deref(), "server replied with "))]
|
||||
RestTooManyRequests(Option<String>),
|
||||
#[error("sent a bad request to embedding server{}{}",
|
||||
if ConfigurationSource::User == *.1 {
|
||||
"\n - Hint: check that the `request` in the embedder configuration matches the remote server's API"
|
||||
} else {
|
||||
""
|
||||
},
|
||||
option_info(.0.as_deref(), "server replied with "))]
|
||||
RestBadRequest(Option<String>, ConfigurationSource),
|
||||
#[error("received internal error HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))]
|
||||
RestInternalServerError(u16, Option<String>),
|
||||
#[error("received unexpected HTTP {0} from embedding server{}", option_info(.1.as_deref(), "server replied with "))]
|
||||
RestOtherStatusCode(u16, Option<String>),
|
||||
#[error("could not reach embedding server:\n - {0}")]
|
||||
RestNetwork(ureq::Transport),
|
||||
#[error("error extracting embeddings from the response:\n - {0}")]
|
||||
RestExtractionError(String),
|
||||
#[error("was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")]
|
||||
UnexpectedDimension(usize, usize),
|
||||
#[error("no embedding was produced")]
|
||||
MissingEmbedding,
|
||||
#[error(transparent)]
|
||||
PanicInThreadPool(#[from] PanicCatched),
|
||||
}
|
||||
|
||||
/// Formats an optional piece of extra information as an indented bullet line.
///
/// Returns an empty string when `info` is `None`; otherwise returns a new line containing
/// `prefix` followed by the information wrapped in backticks.
fn option_info(info: Option<&str>, prefix: &str) -> String {
    info.map(|details| format!("\n - {prefix}`{details}`")).unwrap_or_default()
}
|
||||
|
||||
impl EmbedError {
|
||||
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
|
||||
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn tensor_shape(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn tensor_value(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn model_forward(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestResponseDeserialization(error),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_unauthorized(
|
||||
error_response: Option<String>,
|
||||
configuration_source: ConfigurationSource,
|
||||
) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestUnauthorized(error_response, configuration_source),
|
||||
fault: FaultSource::User,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestTooManyRequests(error_response),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_bad_request(
|
||||
error_response: Option<String>,
|
||||
configuration_source: ConfigurationSource,
|
||||
) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestBadRequest(error_response, configuration_source),
|
||||
fault: FaultSource::User,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_internal_server_error(
|
||||
code: u16,
|
||||
error_response: Option<String>,
|
||||
) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
|
||||
fault: FaultSource::Undecided,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::UnexpectedDimension(expected, got),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
pub(crate) fn missing_embedding() -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::MissingEmbedding, fault: FaultSource::Undecided }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_extraction_error(error: String) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestExtractionError(error), fault: FaultSource::Runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct NewEmbedderError {
|
||||
pub kind: NewEmbedderErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
impl NewEmbedderError {
|
||||
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
|
||||
let open_config = OpenConfig { filename: config_filename, inner };
|
||||
|
||||
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn deserialize_config(
|
||||
model_name: String,
|
||||
config: String,
|
||||
config_filename: PathBuf,
|
||||
inner: serde_json::Error,
|
||||
) -> NewEmbedderError {
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(value) => {
|
||||
let value: serde_json::Value = value;
|
||||
let architectures = match value.get("architectures") {
|
||||
Some(serde_json::Value::Array(architectures)) => architectures
|
||||
.iter()
|
||||
.filter_map(|value| match value {
|
||||
serde_json::Value::String(s) => Some(s.to_owned()),
|
||||
_ => None,
|
||||
})
|
||||
.collect(),
|
||||
_ => vec![],
|
||||
};
|
||||
|
||||
let unsupported_model = UnsupportedModel { model_name, inner, architectures };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::UnsupportedModel(unsupported_model),
|
||||
fault: FaultSource::User,
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
let deserialize_config =
|
||||
DeserializeConfig { model_name, filename: config_filename, inner: error };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open_tokenizer(
|
||||
tokenizer_filename: PathBuf,
|
||||
inner: Box<dyn std::error::Error + Send + Sync>,
|
||||
) -> NewEmbedderError {
|
||||
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_api_fail(inner: ApiError) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn api_get(inner: ApiError) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
|
||||
}
|
||||
|
||||
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn load_model(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_could_not_parse_template(message: String) -> NewEmbedderError {
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::CouldNotParseTemplate(message),
|
||||
fault: FaultSource::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open config at {filename}: {inner}")]
|
||||
pub struct OpenConfig {
|
||||
pub filename: PathBuf,
|
||||
pub inner: std::io::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
|
||||
pub struct DeserializeConfig {
|
||||
pub model_name: String,
|
||||
pub filename: PathBuf,
|
||||
pub inner: serde_json::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
|
||||
if architectures.is_empty() {
|
||||
"\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
|
||||
} else {
|
||||
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
|
||||
})]
|
||||
pub struct UnsupportedModel {
|
||||
pub model_name: String,
|
||||
pub inner: serde_json::Error,
|
||||
pub architectures: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open tokenizer at {filename}: {inner}")]
|
||||
pub struct OpenTokenizer {
|
||||
pub filename: PathBuf,
|
||||
#[source]
|
||||
pub inner: Box<dyn std::error::Error + Send + Sync>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum NewEmbedderErrorKind {
|
||||
// hf
|
||||
#[error(transparent)]
|
||||
OpenConfig(OpenConfig),
|
||||
#[error(transparent)]
|
||||
DeserializeConfig(DeserializeConfig),
|
||||
#[error(transparent)]
|
||||
UnsupportedModel(UnsupportedModel),
|
||||
#[error(transparent)]
|
||||
OpenTokenizer(OpenTokenizer),
|
||||
#[error("could not build weights from Pytorch weights:\n - {0}")]
|
||||
PytorchWeight(candle_core::Error),
|
||||
#[error("could not build weights from Safetensor weights:\n - {0}")]
|
||||
SafetensorWeight(candle_core::Error),
|
||||
#[error("could not spawn HG_HUB API client:\n - {0}")]
|
||||
NewApiFail(ApiError),
|
||||
#[error("fetching file from HG_HUB failed:\n - {0}")]
|
||||
ApiGet(ApiError),
|
||||
#[error("could not determine model dimensions:\n - test embedding failed with {0}")]
|
||||
CouldNotDetermineDimension(EmbedError),
|
||||
#[error("loading model failed:\n - {0}")]
|
||||
LoadModel(candle_core::Error),
|
||||
#[error("{0}")]
|
||||
CouldNotParseTemplate(String),
|
||||
}
|
||||
|
||||
/// Field names that look like misspellings of `_vectors`, with their occurrence counts.
pub struct PossibleEmbeddingMistakes {
    /// Maps each suspicious field name to its count in the field distribution.
    vectors_mistakes: BTreeMap<String, u64>,
}
|
||||
|
||||
impl PossibleEmbeddingMistakes {
|
||||
pub fn new(field_distribution: &FieldDistribution) -> Self {
|
||||
let mut vectors_mistakes = BTreeMap::new();
|
||||
let builder = levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true);
|
||||
let automata = builder.build_dfa("_vectors");
|
||||
for (field, count) in field_distribution {
|
||||
if *count == 0 {
|
||||
continue;
|
||||
}
|
||||
if field.contains('.') {
|
||||
continue;
|
||||
}
|
||||
match automata.eval(field) {
|
||||
levenshtein_automata::Distance::Exact(0) => continue,
|
||||
levenshtein_automata::Distance::Exact(_) => {
|
||||
vectors_mistakes.insert(field.to_string(), *count);
|
||||
}
|
||||
levenshtein_automata::Distance::AtLeast(_) => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Self { vectors_mistakes }
|
||||
}
|
||||
|
||||
pub fn vector_mistakes(&self) -> impl Iterator<Item = (&str, u64)> {
|
||||
self.vectors_mistakes.iter().map(|(misspelling, count)| (misspelling.as_str(), *count))
|
||||
}
|
||||
|
||||
pub fn embedder_mistakes<'a>(
|
||||
&'a self,
|
||||
embedder_name: &'a str,
|
||||
unused_vectors_distributions: &'a UnusedVectorsDistribution,
|
||||
) -> impl Iterator<Item = (&'a str, u64)> + 'a {
|
||||
let builder = levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true);
|
||||
let automata = builder.build_dfa(embedder_name);
|
||||
|
||||
unused_vectors_distributions.0.iter().filter_map(move |(field, count)| {
|
||||
match automata.eval(field) {
|
||||
levenshtein_automata::Distance::Exact(0) => None,
|
||||
levenshtein_automata::Distance::Exact(_) => Some((field.as_str(), *count)),
|
||||
levenshtein_automata::Distance::AtLeast(_) => None,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Occurrence count per vector name, filled through `append`.
#[derive(Default)]
pub struct UnusedVectorsDistribution(BTreeMap<String, u64>);
|
||||
|
||||
impl UnusedVectorsDistribution {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn append(&mut self, parsed_vectors_diff: ParsedVectorsDiff) {
|
||||
for name in parsed_vectors_diff.into_new_vectors_keys_iter() {
|
||||
*self.0.entry(name).or_default() += 1;
|
||||
}
|
||||
}
|
||||
}
|
214
crates/milli/src/vector/hf.rs
Normal file
214
crates/milli/src/vector/hf.rs
Normal file
|
@ -0,0 +1,214 @@
|
|||
use candle_core::Tensor;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
pub use super::error::{EmbedError, Error, NewEmbedderError};
|
||||
use super::{DistributionShift, Embedding, Embeddings};
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
Copy,
|
||||
Default,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
serde::Deserialize,
|
||||
serde::Serialize,
|
||||
deserr::Deserr,
|
||||
)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
enum WeightSource {
|
||||
#[default]
|
||||
Safetensors,
|
||||
Pytorch,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub model: String,
|
||||
pub revision: Option<String>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
|
||||
distribution: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform embedding of documents and queries
|
||||
pub struct Embedder {
|
||||
model: BertModel,
|
||||
tokenizer: Tokenizer,
|
||||
options: EmbedderOptions,
|
||||
dimensions: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Embedder {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Embedder")
|
||||
.field("model", &self.options.model)
|
||||
.field("tokenizer", &self.tokenizer)
|
||||
.field("options", &self.options)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
let device = match candle_core::Device::cuda_if_available(0) {
|
||||
Ok(device) => device,
|
||||
Err(error) => {
|
||||
tracing::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
|
||||
candle_core::Device::Cpu
|
||||
}
|
||||
};
|
||||
let repo = match options.revision.clone() {
|
||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||
None => Repo::model(options.model.clone()),
|
||||
};
|
||||
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
|
||||
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
||||
let (weights, source) = {
|
||||
api.get("model.safetensors")
|
||||
.map(|filename| (filename, WeightSource::Safetensors))
|
||||
.or_else(|_| {
|
||||
api.get("pytorch_model.bin")
|
||||
.map(|filename| (filename, WeightSource::Pytorch))
|
||||
})
|
||||
.map_err(NewEmbedderError::api_get)?
|
||||
};
|
||||
(config, tokenizer, weights, source)
|
||||
};
|
||||
|
||||
let config = std::fs::read_to_string(&config_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
||||
let config: Config = serde_json::from_str(&config).map_err(|inner| {
|
||||
NewEmbedderError::deserialize_config(
|
||||
options.model.clone(),
|
||||
config,
|
||||
config_filename,
|
||||
inner,
|
||||
)
|
||||
})?;
|
||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||
|
||||
let vb = match weight_source {
|
||||
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
||||
.map_err(NewEmbedderError::pytorch_weight)?,
|
||||
WeightSource::Safetensors => unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
|
||||
.map_err(NewEmbedderError::safetensor_weight)?
|
||||
},
|
||||
};
|
||||
|
||||
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
||||
|
||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||
} else {
|
||||
let pp = PaddingParams {
|
||||
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||
..Default::default()
|
||||
};
|
||||
tokenizer.with_padding(Some(pp));
|
||||
}
|
||||
|
||||
let mut this = Self { model, tokenizer, options, dimensions: 0 };
|
||||
|
||||
let embeddings = this
|
||||
.embed(vec!["test".into()])
|
||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||
this.dimensions = embeddings.first().unwrap().dimension();
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
pub fn embed(
|
||||
&self,
|
||||
mut texts: Vec<String>,
|
||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let tokens = match texts.len() {
|
||||
1 => vec![self
|
||||
.tokenizer
|
||||
.encode(texts.pop().unwrap(), true)
|
||||
.map_err(EmbedError::tokenize)?],
|
||||
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
|
||||
};
|
||||
let token_ids = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
tokens.truncate(512);
|
||||
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
||||
})
|
||||
.collect::<Result<Vec<_>, EmbedError>>()?;
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings =
|
||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) =
|
||||
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||
|
||||
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||
.map_err(EmbedError::tensor_shape)?;
|
||||
|
||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.options.distribution.or_else(|| {
|
||||
if self.options.model == "BAAI/bge-base-en-v1.5" {
|
||||
Some(DistributionShift {
|
||||
current_mean: ordered_float::OrderedFloat(0.85),
|
||||
current_sigma: ordered_float::OrderedFloat(0.1),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
970
crates/milli/src/vector/json_template.rs
Normal file
970
crates/milli/src/vector/json_template.rs
Normal file
|
@ -0,0 +1,970 @@
|
|||
//! Module to manipulate JSON templates.
|
||||
//!
|
||||
//! This module allows two main operations:
|
||||
//! 1. Render JSON values from a template and a context value.
|
||||
//! 2. Retrieve data from a template and JSON values.
|
||||
|
||||
#![warn(rustdoc::broken_intra_doc_links)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
/// Sequence of components leading to a value nested inside a JSON value.
type ValuePath = Vec<PathComponent>;
|
||||
|
||||
/// Encapsulates a JSON template and allows injecting and extracting values from it.
|
||||
#[derive(Debug)]
|
||||
pub struct ValueTemplate {
|
||||
template: Value,
|
||||
value_kind: ValueKind,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ValueKind {
|
||||
Single(ValuePath),
|
||||
Array(ArrayPath),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ArrayPath {
|
||||
repeated_value: Value,
|
||||
path_to_array: ValuePath,
|
||||
value_path_in_array: ValuePath,
|
||||
}
|
||||
|
||||
/// One step of a path leading to a `Value`.
#[derive(Debug, Clone)]
pub enum PathComponent {
    /// A key inside of an object
    MapKey(String),
    /// An index inside of an array
    ArrayIndex(usize),
}
|
||||
|
||||
impl PartialEq for PathComponent {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(Self::MapKey(l0), Self::MapKey(r0)) => l0 == r0,
|
||||
(Self::ArrayIndex(l0), Self::ArrayIndex(r0)) => l0 == r0,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker implementation with no methods of its own.
impl Eq for PathComponent {}
|
||||
|
||||
/// Error that occurs when a template is not provided with the value it needs for injection.
#[derive(Debug)]
pub struct MissingValue;
|
||||
|
||||
/// Error that occurs when trying to parse a template in [`ValueTemplate::new`]
|
||||
#[derive(Debug)]
|
||||
pub enum TemplateParsingError {
|
||||
/// A repeat string appears inside a repeated value
|
||||
NestedRepeatString(ValuePath),
|
||||
/// A repeat string appears outside of an array
|
||||
RepeatStringNotInArray(ValuePath),
|
||||
/// A repeat string appears in an array, but not in the second position
|
||||
BadIndexForRepeatString(ValuePath, usize),
|
||||
/// A repeated value lacks a placeholder
|
||||
MissingPlaceholderInRepeatedValue(ValuePath),
|
||||
/// Multiple repeat string appear in the template
|
||||
MultipleRepeatString(ValuePath, ValuePath),
|
||||
/// Multiple placeholder strings appear in the template
|
||||
MultiplePlaceholderString(ValuePath, ValuePath),
|
||||
/// No placeholder string appear in the template
|
||||
MissingPlaceholderString,
|
||||
/// A placeholder appears both inside a repeated value and outside of it
|
||||
BothArrayAndSingle {
|
||||
/// Path to the single value
|
||||
single_path: ValuePath,
|
||||
/// Path to the array of repeated values
|
||||
path_to_array: ValuePath,
|
||||
/// Path to placeholder inside each repeated value, starting from the array
|
||||
array_to_placeholder: ValuePath,
|
||||
},
|
||||
}
|
||||
|
||||
impl TemplateParsingError {
|
||||
/// Produce an error message from the error kind, the name of the root object, the placeholder string and the repeat string
|
||||
pub fn error_message(&self, root: &str, placeholder: &str, repeat: &str) -> String {
|
||||
match self {
|
||||
TemplateParsingError::NestedRepeatString(path) => {
|
||||
format!(
|
||||
r#"in {}: "{repeat}" appears nested inside of a value that is itself repeated"#,
|
||||
path_with_root(root, path)
|
||||
)
|
||||
}
|
||||
TemplateParsingError::RepeatStringNotInArray(path) => format!(
|
||||
r#"in {}: "{repeat}" appears outside of an array"#,
|
||||
path_with_root(root, path)
|
||||
),
|
||||
TemplateParsingError::BadIndexForRepeatString(path, index) => format!(
|
||||
r#"in {}: "{repeat}" expected at position #1, but found at position #{index}"#,
|
||||
path_with_root(root, path)
|
||||
),
|
||||
TemplateParsingError::MissingPlaceholderInRepeatedValue(path) => format!(
|
||||
r#"in {}: Expected "{placeholder}" inside of the repeated value"#,
|
||||
path_with_root(root, path)
|
||||
),
|
||||
TemplateParsingError::MultipleRepeatString(current, previous) => format!(
|
||||
r#"in {}: Found "{repeat}", but it was already present in {}"#,
|
||||
path_with_root(root, current),
|
||||
path_with_root(root, previous)
|
||||
),
|
||||
TemplateParsingError::MultiplePlaceholderString(current, previous) => format!(
|
||||
r#"in {}: Found "{placeholder}", but it was already present in {}"#,
|
||||
path_with_root(root, current),
|
||||
path_with_root(root, previous)
|
||||
),
|
||||
TemplateParsingError::MissingPlaceholderString => {
|
||||
format!(r#"in `{root}`: "{placeholder}" not found"#)
|
||||
}
|
||||
TemplateParsingError::BothArrayAndSingle {
|
||||
single_path,
|
||||
path_to_array,
|
||||
array_to_placeholder,
|
||||
} => {
|
||||
let path_to_first_repeated = path_to_array
|
||||
.iter()
|
||||
.chain(std::iter::once(&PathComponent::ArrayIndex(0)))
|
||||
.chain(array_to_placeholder.iter());
|
||||
format!(
|
||||
r#"in {}: Found "{placeholder}", but it was already present in {} (repeated)"#,
|
||||
path_with_root(root, single_path),
|
||||
path_with_root(root, path_to_first_repeated)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn prepend_path(self, mut prepended_path: ValuePath) -> Self {
|
||||
match self {
|
||||
TemplateParsingError::NestedRepeatString(mut path) => {
|
||||
prepended_path.append(&mut path);
|
||||
TemplateParsingError::NestedRepeatString(prepended_path)
|
||||
}
|
||||
TemplateParsingError::RepeatStringNotInArray(mut path) => {
|
||||
prepended_path.append(&mut path);
|
||||
TemplateParsingError::RepeatStringNotInArray(prepended_path)
|
||||
}
|
||||
TemplateParsingError::BadIndexForRepeatString(mut path, index) => {
|
||||
prepended_path.append(&mut path);
|
||||
TemplateParsingError::BadIndexForRepeatString(prepended_path, index)
|
||||
}
|
||||
TemplateParsingError::MissingPlaceholderInRepeatedValue(mut path) => {
|
||||
prepended_path.append(&mut path);
|
||||
TemplateParsingError::MissingPlaceholderInRepeatedValue(prepended_path)
|
||||
}
|
||||
TemplateParsingError::MultipleRepeatString(mut path, older_path) => {
|
||||
let older_prepended_path =
|
||||
prepended_path.iter().cloned().chain(older_path).collect();
|
||||
prepended_path.append(&mut path);
|
||||
TemplateParsingError::MultipleRepeatString(prepended_path, older_prepended_path)
|
||||
}
|
||||
TemplateParsingError::MultiplePlaceholderString(mut path, older_path) => {
|
||||
let older_prepended_path =
|
||||
prepended_path.iter().cloned().chain(older_path).collect();
|
||||
prepended_path.append(&mut path);
|
||||
TemplateParsingError::MultiplePlaceholderString(
|
||||
prepended_path,
|
||||
older_prepended_path,
|
||||
)
|
||||
}
|
||||
TemplateParsingError::MissingPlaceholderString => {
|
||||
TemplateParsingError::MissingPlaceholderString
|
||||
}
|
||||
TemplateParsingError::BothArrayAndSingle {
|
||||
single_path,
|
||||
mut path_to_array,
|
||||
array_to_placeholder,
|
||||
} => {
|
||||
// note, this case is not super logical, but is also likely to be dead code
|
||||
let single_prepended_path =
|
||||
prepended_path.iter().cloned().chain(single_path).collect();
|
||||
prepended_path.append(&mut path_to_array);
|
||||
// we don't prepend the array_to_placeholder path as it is the array path that is prepended
|
||||
TemplateParsingError::BothArrayAndSingle {
|
||||
single_path: single_prepended_path,
|
||||
path_to_array: prepended_path,
|
||||
array_to_placeholder,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error that occurs when [`ValueTemplate::extract`] fails.
|
||||
#[derive(Debug)]
|
||||
pub struct ExtractionError {
|
||||
/// The cause of the failure
|
||||
pub kind: ExtractionErrorKind,
|
||||
/// The context where the failure happened: the operation that failed
|
||||
pub context: ExtractionErrorContext,
|
||||
}
|
||||
|
||||
impl ExtractionError {
|
||||
/// Produce an error message from the error, the name of the root object, the placeholder string and the expected value type
|
||||
pub fn error_message(
|
||||
&self,
|
||||
root: &str,
|
||||
placeholder: &str,
|
||||
expected_value_type: &str,
|
||||
) -> String {
|
||||
let context = match &self.context {
|
||||
ExtractionErrorContext::ExtractingSingleValue => {
|
||||
format!(r#"extracting a single "{placeholder}""#)
|
||||
}
|
||||
ExtractionErrorContext::FindingPathToArray => {
|
||||
format!(r#"extracting the array of "{placeholder}"s"#)
|
||||
}
|
||||
ExtractionErrorContext::ExtractingArrayItem(index) => {
|
||||
format!(r#"extracting item #{index} from the array of "{placeholder}"s"#)
|
||||
}
|
||||
};
|
||||
match &self.kind {
|
||||
ExtractionErrorKind::MissingPathComponent { missing_index, path, key_suggestion } => {
|
||||
let last_named_object = last_named_object(root, path.iter().take(*missing_index));
|
||||
format!(
|
||||
"in {}, while {context}, configuration expects {}, which is missing in response{}",
|
||||
path_with_root(root, path.iter().take(*missing_index)),
|
||||
missing_component(path.get(*missing_index)),
|
||||
match key_suggestion {
|
||||
Some(key_suggestion) => format!("\n - Hint: {last_named_object} has key `{key_suggestion}`, did you mean {} in embedder configuration?",
|
||||
path_with_root(root, path.iter().take(*missing_index).chain(std::iter::once(&PathComponent::MapKey(key_suggestion.to_owned()))))),
|
||||
None => "".to_owned(),
|
||||
}
|
||||
)
|
||||
}
|
||||
ExtractionErrorKind::WrongPathComponent { wrong_component, index, path } => {
|
||||
let last_named_object = last_named_object(root, path.iter().take(*index));
|
||||
format!(
|
||||
"in {}, while {context}, configuration expects {last_named_object} to be {} but server sent {wrong_component}",
|
||||
path_with_root(root, path.iter().take(*index)),
|
||||
expected_component(path.get(*index))
|
||||
)
|
||||
}
|
||||
ExtractionErrorKind::DeserializationError { error, path } => {
|
||||
let last_named_object = last_named_object(root, path);
|
||||
format!(
|
||||
"in {}, while {context}, expected {last_named_object} to be {expected_value_type}, but failed to parse server response:\n - {error}",
|
||||
path_with_root(root, path)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn missing_component(component: Option<&PathComponent>) -> String {
|
||||
match component {
|
||||
Some(PathComponent::ArrayIndex(index)) => {
|
||||
format!(r#"item #{index}"#)
|
||||
}
|
||||
Some(PathComponent::MapKey(key)) => {
|
||||
format!(r#"key "{key}""#)
|
||||
}
|
||||
None => "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn expected_component(component: Option<&PathComponent>) -> String {
|
||||
match component {
|
||||
Some(PathComponent::ArrayIndex(index)) => {
|
||||
format!(r#"an array with at least {} item(s)"#, index.saturating_add(1))
|
||||
}
|
||||
Some(PathComponent::MapKey(key)) => {
|
||||
format!("an object with key `{}`", key)
|
||||
}
|
||||
None => "unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn last_named_object<'a>(
|
||||
root: &'a str,
|
||||
path: impl IntoIterator<Item = &'a PathComponent> + 'a,
|
||||
) -> LastNamedObject<'a> {
|
||||
let mut last_named_object = LastNamedObject::Object { name: root };
|
||||
for component in path.into_iter() {
|
||||
last_named_object = match (component, last_named_object) {
|
||||
(PathComponent::MapKey(name), _) => LastNamedObject::Object { name },
|
||||
(PathComponent::ArrayIndex(index), LastNamedObject::Object { name }) => {
|
||||
LastNamedObject::ArrayInsideObject { object_name: name, index: *index }
|
||||
}
|
||||
(
|
||||
PathComponent::ArrayIndex(index),
|
||||
LastNamedObject::ArrayInsideObject { object_name, index: _ },
|
||||
) => LastNamedObject::NestedArrayInsideObject {
|
||||
object_name,
|
||||
index: *index,
|
||||
nesting_level: 0,
|
||||
},
|
||||
(
|
||||
PathComponent::ArrayIndex(index),
|
||||
LastNamedObject::NestedArrayInsideObject { object_name, index: _, nesting_level },
|
||||
) => LastNamedObject::NestedArrayInsideObject {
|
||||
object_name,
|
||||
index: *index,
|
||||
nesting_level: nesting_level.saturating_add(1),
|
||||
},
|
||||
}
|
||||
}
|
||||
last_named_object
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LastNamedObject<'_> {
    /// Writes the description used in extraction error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LastNamedObject::Object { name } => write!(f, "`{name}`"),
            LastNamedObject::ArrayInsideObject { object_name, index } => {
                write!(f, "item #{index} inside `{object_name}`")
            }
            LastNamedObject::NestedArrayInsideObject { object_name, index, nesting_level } => {
                if *nesting_level == 0 {
                    write!(f, "item #{index} inside nested array in `{object_name}`")
                } else {
                    // `nesting_level` is built with `saturating_add` and may therefore be
                    // `usize::MAX`: saturate here too rather than overflow while formatting.
                    let levels = nesting_level.saturating_add(1);
                    write!(
                        f,
                        "item #{index} inside nested array ({levels} levels of nesting) in `{object_name}`"
                    )
                }
            }
        }
    }
}

/// Description of a value relative to the closest enclosing value that has a name.
#[derive(Debug, Clone, Copy)]
enum LastNamedObject<'a> {
    /// The value is itself a named value.
    Object { name: &'a str },
    /// The value is item `index` of an array directly inside the value named `object_name`.
    ArrayInsideObject { object_name: &'a str, index: usize },
    /// The value is item `index` of an array nested inside other arrays of the value named
    /// `object_name`. `nesting_level` is 0 for the first level of nesting.
    NestedArrayInsideObject { object_name: &'a str, index: usize, nesting_level: usize },
}
|
||||
|
||||
/// Builds a string representation of a path, preprending the name of the root value.
|
||||
pub fn path_with_root<'a>(
|
||||
root: &str,
|
||||
path: impl IntoIterator<Item = &'a PathComponent> + 'a,
|
||||
) -> String {
|
||||
use std::fmt::Write as _;
|
||||
let mut res = format!("`{root}");
|
||||
for component in path.into_iter() {
|
||||
match component {
|
||||
PathComponent::MapKey(key) => {
|
||||
let _ = write!(&mut res, ".{key}");
|
||||
}
|
||||
PathComponent::ArrayIndex(index) => {
|
||||
let _ = write!(&mut res, "[{index}]");
|
||||
}
|
||||
}
|
||||
}
|
||||
res.push('`');
|
||||
res
|
||||
}
|
||||
|
||||
/// The operation that was being performed when an extraction failed.
#[derive(Debug, Clone, Copy)]
pub enum ExtractionErrorContext {
    /// The failure happened while extracting the value found at a single location.
    ExtractingSingleValue,
    /// The failure happened while locating the array that contains the values.
    FindingPathToArray,
    /// The failure happened while extracting the value of the array item at the given index.
    ExtractingArrayItem(usize),
}
|
||||
|
||||
/// Kind of errors that can happen during extraction
|
||||
#[derive(Debug)]
|
||||
pub enum ExtractionErrorKind {
|
||||
/// An expected path component is missing
|
||||
MissingPathComponent {
|
||||
/// Index of the missing component in the path
|
||||
missing_index: usize,
|
||||
/// Path where a component is missing
|
||||
path: ValuePath,
|
||||
/// Possible matching key in object
|
||||
key_suggestion: Option<String>,
|
||||
},
|
||||
/// An expected path component cannot be found because its container is the wrong type
|
||||
WrongPathComponent {
|
||||
/// String representation of the wrong component
|
||||
wrong_component: String,
|
||||
/// Index of the wrong component in the path
|
||||
index: usize,
|
||||
/// Path where a component has the wrong type
|
||||
path: ValuePath,
|
||||
},
|
||||
/// Could not deserialize an extracted value to its requested type
|
||||
DeserializationError {
|
||||
/// inner deserialization error
|
||||
error: serde_json::Error,
|
||||
/// path to extracted value
|
||||
path: ValuePath,
|
||||
},
|
||||
}
|
||||
|
||||
enum ArrayParsingContext<'a> {
|
||||
Nested,
|
||||
NotNested(&'a mut Option<ArrayPath>),
|
||||
}
|
||||
|
||||
impl ValueTemplate {
|
||||
/// Prepare a template for injection or extraction.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `template`: JSON value that acts a template. Its placeholder values will be replaced by actual values during injection,
|
||||
/// and actual values will be recovered from their location during extraction.
|
||||
/// - `placeholder_string`: Value that a JSON string should assume to act as a placeholder value that can be injected into or
|
||||
/// extracted from.
|
||||
/// - `repeat_string`: Sentinel value that can be placed as the second value in an array to indicate that the first value can be repeated
|
||||
/// any number of times. The first value should contain exactly one placeholder string.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// - [`TemplateParsingError`]: refer to the documentation of this type
|
||||
pub fn new(
|
||||
template: Value,
|
||||
placeholder_string: &str,
|
||||
repeat_string: &str,
|
||||
) -> Result<Self, TemplateParsingError> {
|
||||
let mut value_path = None;
|
||||
let mut array_path = None;
|
||||
let mut current_path = Vec::new();
|
||||
Self::parse_value(
|
||||
&template,
|
||||
placeholder_string,
|
||||
repeat_string,
|
||||
&mut value_path,
|
||||
&mut ArrayParsingContext::NotNested(&mut array_path),
|
||||
&mut current_path,
|
||||
)?;
|
||||
|
||||
let value_kind = match (array_path, value_path) {
|
||||
(None, None) => return Err(TemplateParsingError::MissingPlaceholderString),
|
||||
(None, Some(value_path)) => ValueKind::Single(value_path),
|
||||
(Some(array_path), None) => ValueKind::Array(array_path),
|
||||
(Some(array_path), Some(value_path)) => {
|
||||
return Err(TemplateParsingError::BothArrayAndSingle {
|
||||
single_path: value_path,
|
||||
path_to_array: array_path.path_to_array,
|
||||
array_to_placeholder: array_path.value_path_in_array,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self { template, value_kind })
|
||||
}
|
||||
|
||||
/// Whether there is a placeholder that can be repeated.
|
||||
///
|
||||
/// - During injection, all values are injected in the array placeholder,
|
||||
/// - During extraction, all repeatable placeholders are extracted from the array.
|
||||
pub fn has_array_value(&self) -> bool {
|
||||
matches!(self.value_kind, ValueKind::Array(_))
|
||||
}
|
||||
|
||||
/// Render a value from the template and context values.
|
||||
///
|
||||
/// # Error
|
||||
///
|
||||
/// - [`MissingValue`]: if the number of injected values is 0.
|
||||
pub fn inject(&self, values: impl IntoIterator<Item = Value>) -> Result<Value, MissingValue> {
|
||||
let mut rendered = self.template.clone();
|
||||
let mut values = values.into_iter();
|
||||
|
||||
match &self.value_kind {
|
||||
ValueKind::Single(injection_path) => {
|
||||
let Some(injected_value) = values.next() else { return Err(MissingValue) };
|
||||
inject_value(&mut rendered, injection_path, injected_value);
|
||||
}
|
||||
ValueKind::Array(ArrayPath { repeated_value, path_to_array, value_path_in_array }) => {
|
||||
// 1. build the array of repeated values
|
||||
let mut array = Vec::new();
|
||||
for injected_value in values {
|
||||
let mut repeated_value = repeated_value.clone();
|
||||
inject_value(&mut repeated_value, value_path_in_array, injected_value);
|
||||
array.push(repeated_value);
|
||||
}
|
||||
|
||||
if array.is_empty() {
|
||||
return Err(MissingValue);
|
||||
}
|
||||
// 2. inject at the injection point in the rendered value
|
||||
inject_value(&mut rendered, path_to_array, Value::Array(array));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(rendered)
|
||||
}
|
||||
|
||||
/// Extract sub values from the template and a value.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// - if a single placeholder is missing.
|
||||
/// - if there is no value corresponding to an array placeholder
|
||||
/// - if the value corresponding to an array placeholder is not an array
|
||||
pub fn extract<T>(&self, mut value: Value) -> Result<Vec<T>, ExtractionError>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
Ok(match &self.value_kind {
|
||||
ValueKind::Single(extraction_path) => {
|
||||
let extracted_value =
|
||||
extract_value(extraction_path, &mut value).with_context(|kind| {
|
||||
ExtractionError {
|
||||
kind,
|
||||
context: ExtractionErrorContext::ExtractingSingleValue,
|
||||
}
|
||||
})?;
|
||||
vec![extracted_value]
|
||||
}
|
||||
ValueKind::Array(ArrayPath {
|
||||
repeated_value: _,
|
||||
path_to_array,
|
||||
value_path_in_array,
|
||||
}) => {
|
||||
// get the array
|
||||
let array = extract_value(path_to_array, &mut value).with_context(|kind| {
|
||||
ExtractionError { kind, context: ExtractionErrorContext::FindingPathToArray }
|
||||
})?;
|
||||
let array = match array {
|
||||
Value::Array(array) => array,
|
||||
not_array => {
|
||||
let mut path = path_to_array.clone();
|
||||
path.push(PathComponent::ArrayIndex(0));
|
||||
return Err(ExtractionError {
|
||||
kind: ExtractionErrorKind::WrongPathComponent {
|
||||
wrong_component: format_value(¬_array),
|
||||
index: path_to_array.len(),
|
||||
path,
|
||||
},
|
||||
context: ExtractionErrorContext::FindingPathToArray,
|
||||
});
|
||||
}
|
||||
};
|
||||
let mut extracted_values = Vec::with_capacity(array.len());
|
||||
|
||||
for (index, mut item) in array.into_iter().enumerate() {
|
||||
let extracted_value = extract_value(value_path_in_array, &mut item)
|
||||
.with_context(|kind| ExtractionError {
|
||||
kind,
|
||||
context: ExtractionErrorContext::ExtractingArrayItem(index),
|
||||
})?;
|
||||
extracted_values.push(extracted_value);
|
||||
}
|
||||
|
||||
extracted_values
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_array(
|
||||
array: &[Value],
|
||||
placeholder_string: &str,
|
||||
repeat_string: &str,
|
||||
value_path: &mut Option<ValuePath>,
|
||||
mut array_path: &mut ArrayParsingContext,
|
||||
current_path: &mut ValuePath,
|
||||
) -> Result<(), TemplateParsingError> {
|
||||
// two modes for parsing array.
|
||||
match array {
|
||||
// 1. array contains a repeat string in second position
|
||||
[first, second, rest @ ..] if second == repeat_string => {
|
||||
let ArrayParsingContext::NotNested(array_path) = &mut array_path else {
|
||||
return Err(TemplateParsingError::NestedRepeatString(current_path.clone()));
|
||||
};
|
||||
if let Some(array_path) = array_path {
|
||||
return Err(TemplateParsingError::MultipleRepeatString(
|
||||
current_path.clone(),
|
||||
array_path.path_to_array.clone(),
|
||||
));
|
||||
}
|
||||
if first == repeat_string {
|
||||
return Err(TemplateParsingError::BadIndexForRepeatString(
|
||||
current_path.clone(),
|
||||
0,
|
||||
));
|
||||
}
|
||||
if let Some(position) = rest.iter().position(|value| value == repeat_string) {
|
||||
let position = position + 2;
|
||||
return Err(TemplateParsingError::BadIndexForRepeatString(
|
||||
current_path.clone(),
|
||||
position,
|
||||
));
|
||||
}
|
||||
|
||||
let value_path_in_array = {
|
||||
let mut value_path = None;
|
||||
let mut current_path_in_array = Vec::new();
|
||||
|
||||
Self::parse_value(
|
||||
first,
|
||||
placeholder_string,
|
||||
repeat_string,
|
||||
&mut value_path,
|
||||
&mut ArrayParsingContext::Nested,
|
||||
&mut current_path_in_array,
|
||||
)
|
||||
.map_err(|error| error.prepend_path(current_path.to_vec()))?;
|
||||
|
||||
value_path.ok_or_else(|| {
|
||||
let mut repeated_value_path = current_path.clone();
|
||||
repeated_value_path.push(PathComponent::ArrayIndex(0));
|
||||
TemplateParsingError::MissingPlaceholderInRepeatedValue(repeated_value_path)
|
||||
})?
|
||||
};
|
||||
**array_path = Some(ArrayPath {
|
||||
repeated_value: first.to_owned(),
|
||||
path_to_array: current_path.clone(),
|
||||
value_path_in_array,
|
||||
});
|
||||
}
|
||||
// 2. array does not contain a repeat string
|
||||
array => {
|
||||
if let Some(position) = array.iter().position(|value| value == repeat_string) {
|
||||
return Err(TemplateParsingError::BadIndexForRepeatString(
|
||||
current_path.clone(),
|
||||
position,
|
||||
));
|
||||
}
|
||||
for (index, value) in array.iter().enumerate() {
|
||||
current_path.push(PathComponent::ArrayIndex(index));
|
||||
Self::parse_value(
|
||||
value,
|
||||
placeholder_string,
|
||||
repeat_string,
|
||||
value_path,
|
||||
array_path,
|
||||
current_path,
|
||||
)?;
|
||||
current_path.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_object(
|
||||
object: &Map<String, Value>,
|
||||
placeholder_string: &str,
|
||||
repeat_string: &str,
|
||||
value_path: &mut Option<ValuePath>,
|
||||
array_path: &mut ArrayParsingContext,
|
||||
current_path: &mut ValuePath,
|
||||
) -> Result<(), TemplateParsingError> {
|
||||
for (key, value) in object.iter() {
|
||||
current_path.push(PathComponent::MapKey(key.to_owned()));
|
||||
Self::parse_value(
|
||||
value,
|
||||
placeholder_string,
|
||||
repeat_string,
|
||||
value_path,
|
||||
array_path,
|
||||
current_path,
|
||||
)?;
|
||||
current_path.pop();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_value(
|
||||
value: &Value,
|
||||
placeholder_string: &str,
|
||||
repeat_string: &str,
|
||||
value_path: &mut Option<ValuePath>,
|
||||
array_path: &mut ArrayParsingContext,
|
||||
current_path: &mut ValuePath,
|
||||
) -> Result<(), TemplateParsingError> {
|
||||
match value {
|
||||
Value::String(str) => {
|
||||
if placeholder_string == str {
|
||||
if let Some(value_path) = value_path {
|
||||
return Err(TemplateParsingError::MultiplePlaceholderString(
|
||||
current_path.clone(),
|
||||
value_path.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
*value_path = Some(current_path.clone());
|
||||
}
|
||||
if repeat_string == str {
|
||||
return Err(TemplateParsingError::RepeatStringNotInArray(current_path.clone()));
|
||||
}
|
||||
}
|
||||
Value::Null | Value::Bool(_) | Value::Number(_) => {}
|
||||
Value::Array(array) => Self::parse_array(
|
||||
array,
|
||||
placeholder_string,
|
||||
repeat_string,
|
||||
value_path,
|
||||
array_path,
|
||||
current_path,
|
||||
)?,
|
||||
Value::Object(object) => Self::parse_object(
|
||||
object,
|
||||
placeholder_string,
|
||||
repeat_string,
|
||||
value_path,
|
||||
array_path,
|
||||
current_path,
|
||||
)?,
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn inject_value(rendered: &mut Value, injection_path: &Vec<PathComponent>, injected_value: Value) {
|
||||
let mut current_value = rendered;
|
||||
for injection_component in injection_path {
|
||||
current_value = match injection_component {
|
||||
PathComponent::MapKey(key) => current_value.get_mut(key).unwrap(),
|
||||
PathComponent::ArrayIndex(index) => current_value.get_mut(index).unwrap(),
|
||||
}
|
||||
}
|
||||
*current_value = injected_value;
|
||||
}
|
||||
|
||||
fn format_value(value: &Value) -> String {
|
||||
match value {
|
||||
Value::Array(array) => format!("an array of size {}", array.len()),
|
||||
Value::Object(object) => {
|
||||
format!("an object with {} field(s)", object.len())
|
||||
}
|
||||
value => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_value<T>(
|
||||
extraction_path: &[PathComponent],
|
||||
initial_value: &mut Value,
|
||||
) -> Result<T, ExtractionErrorKind>
|
||||
where
|
||||
T: for<'de> Deserialize<'de>,
|
||||
{
|
||||
let mut current_value = initial_value;
|
||||
for (path_index, extraction_component) in extraction_path.iter().enumerate() {
|
||||
current_value = {
|
||||
match extraction_component {
|
||||
PathComponent::MapKey(key) => {
|
||||
if !current_value.is_object() {
|
||||
return Err(ExtractionErrorKind::WrongPathComponent {
|
||||
wrong_component: format_value(current_value),
|
||||
index: path_index,
|
||||
path: extraction_path.to_vec(),
|
||||
});
|
||||
}
|
||||
if let Some(object) = current_value.as_object_mut() {
|
||||
if !object.contains_key(key) {
|
||||
let typos =
|
||||
levenshtein_automata::LevenshteinAutomatonBuilder::new(2, true)
|
||||
.build_dfa(key);
|
||||
let mut key_suggestion = None;
|
||||
'check_typos: for (key, _) in object.iter() {
|
||||
match typos.eval(key) {
|
||||
levenshtein_automata::Distance::Exact(0) => { /* ??? */ }
|
||||
levenshtein_automata::Distance::Exact(_) => {
|
||||
key_suggestion = Some(key.to_owned());
|
||||
break 'check_typos;
|
||||
}
|
||||
levenshtein_automata::Distance::AtLeast(_) => continue,
|
||||
}
|
||||
}
|
||||
return Err(ExtractionErrorKind::MissingPathComponent {
|
||||
missing_index: path_index,
|
||||
path: extraction_path.to_vec(),
|
||||
key_suggestion,
|
||||
});
|
||||
}
|
||||
if let Some(value) = object.get_mut(key) {
|
||||
value
|
||||
} else {
|
||||
// borrow checking limit: the borrow checker cannot be convinced that `object` is no longer mutably borrowed on the
|
||||
// `else` branch of the `if let`, so we cannot return MissingPathComponent here.
|
||||
// As a workaround, we checked that the object does not contain the key above, making this `else` unreachable.
|
||||
unreachable!()
|
||||
}
|
||||
} else {
|
||||
// borrow checking limit: the borrow checker cannot be convinced that `current_value` is no longer mutably borrowed
|
||||
// on the `else` branch of the `if let`, so we cannot return WrongPathComponent here.
|
||||
// As a workaround, we checked that the value was not a map above, making this `else` unreachable.
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
PathComponent::ArrayIndex(index) => {
|
||||
if !current_value.is_array() {
|
||||
return Err(ExtractionErrorKind::WrongPathComponent {
|
||||
wrong_component: format_value(current_value),
|
||||
index: path_index,
|
||||
path: extraction_path.to_vec(),
|
||||
});
|
||||
}
|
||||
match current_value.get_mut(index) {
|
||||
Some(value) => value,
|
||||
None => {
|
||||
return Err(ExtractionErrorKind::MissingPathComponent {
|
||||
missing_index: path_index,
|
||||
path: extraction_path.to_vec(),
|
||||
key_suggestion: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
serde_json::from_value(current_value.take()).map_err(|error| {
|
||||
ExtractionErrorKind::DeserializationError { error, path: extraction_path.to_vec() }
|
||||
})
|
||||
}
|
||||
|
||||
trait ExtractionResultErrorContext<T> {
|
||||
fn with_context<F>(self, f: F) -> Result<T, ExtractionError>
|
||||
where
|
||||
F: FnOnce(ExtractionErrorKind) -> ExtractionError;
|
||||
}
|
||||
|
||||
impl<T> ExtractionResultErrorContext<T> for Result<T, ExtractionErrorKind> {
|
||||
fn with_context<F>(self, f: F) -> Result<T, ExtractionError>
|
||||
where
|
||||
F: FnOnce(ExtractionErrorKind) -> ExtractionError,
|
||||
{
|
||||
match self {
|
||||
Ok(t) => Ok(t),
|
||||
Err(kind) => Err(f(kind)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use serde_json::{json, Value};

    use super::{PathComponent, TemplateParsingError, ValueTemplate};

    /// Builds a template using `{{text}}` as the placeholder and `{{..}}` as the repeat string.
    fn new_template(template: Value) -> Result<ValueTemplate, TemplateParsingError> {
        ValueTemplate::new(template, "{{text}}", "{{..}}")
    }

    /// A template without any placeholder is rejected.
    #[test]
    fn empty_template() {
        let template = json!({
            "toto": "no template at all",
            "titi": ["this", "will", "not", "work"],
            "tutu": null
        });

        let result = new_template(template);
        assert!(matches!(result, Err(TemplateParsingError::MissingPlaceholderString)))
    }

    /// With a single placeholder, only the first injected value is used.
    #[test]
    fn single_template() {
        let template = json!({
            "toto": "text",
            "titi": ["this", "will", "still", "{{text}}"],
            "tutu": null
        });
        let expected = json!({
            "toto": "text",
            "titi": ["this", "will", "still", "work"],
            "tutu": null
        });

        let basic = new_template(template).unwrap();
        assert!(!basic.has_array_value());

        let rendered = basic.inject(vec!["work".into(), Value::Null, "test".into()]).unwrap();
        assert_eq!(rendered, expected);
    }

    /// A second placeholder is rejected, reporting both locations.
    #[test]
    fn too_many_placeholders() {
        let template = json!({
            "toto": "{{text}}",
            "titi": ["this", "will", "still", "{{text}}"],
            "tutu": "text"
        });

        let Err(TemplateParsingError::MultiplePlaceholderString(left, right)) =
            new_template(template)
        else {
            panic!("should error")
        };

        assert_eq!(left, vec![PathComponent::MapKey("titi".into()), PathComponent::ArrayIndex(3)]);
        assert_eq!(right, vec![PathComponent::MapKey("toto".into())])
    }

    /// With a repeated placeholder, injection produces one item per value and extraction
    /// recovers the injected values.
    #[test]
    fn dynamic_template() {
        let template = json!({
            "toto": "text",
            "titi": [{
                "type": "text",
                "data": "{{text}}"
            }, "{{..}}"],
            "tutu": null
        });

        let basic = new_template(template).unwrap();
        assert!(basic.has_array_value());

        let injected_values = vec![
            "work".into(),
            Value::Null,
            42.into(),
            "test".into(),
            "tata".into(),
            "titi".into(),
            "tutu".into(),
        ];

        let expected = json!({
            "toto": "text",
            "titi": [
                {
                    "type": "text",
                    "data": "work"
                },
                {
                    "type": "text",
                    "data": Value::Null
                },
                {
                    "type": "text",
                    "data": 42
                },
                {
                    "type": "text",
                    "data": "test"
                },
                {
                    "type": "text",
                    "data": "tata"
                },
                {
                    "type": "text",
                    "data": "titi"
                },
                {
                    "type": "text",
                    "data": "tutu"
                }
            ],
            "tutu": null
        });

        let rendered = basic.inject(injected_values.clone()).unwrap();
        assert_eq!(rendered, expected);

        let extracted_values: Vec<Value> = basic.extract(rendered).unwrap();
        assert_eq!(extracted_values, injected_values);
    }
}
|
40
crates/milli/src/vector/manual.rs
Normal file
40
crates/milli/src/vector/manual.rs
Normal file
|
@ -0,0 +1,40 @@
|
|||
use super::error::EmbedError;
|
||||
use super::{DistributionShift, Embeddings};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Embedder {
|
||||
dimensions: usize,
|
||||
distribution: Option<DistributionShift>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub dimensions: usize,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Self {
|
||||
Self { dimensions: options.dimensions, distribution: options.distribution }
|
||||
}
|
||||
|
||||
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let Some(text) = texts.pop() else { return Ok(Default::default()) };
|
||||
Err(EmbedError::embed_on_manual_embedder(text.chars().take(250).collect()))
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.distribution
|
||||
}
|
||||
}
|
606
crates/milli/src/vector/mod.rs
Normal file
606
crates/milli/src/vector/mod.rs
Normal file
|
@ -0,0 +1,606 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arroy::distances::{Angular, BinaryQuantizedAngular};
|
||||
use arroy::ItemId;
|
||||
use deserr::{DeserializeError, Deserr};
|
||||
use heed::{RoTxn, RwTxn, Unspecified};
|
||||
use ordered_float::OrderedFloat;
|
||||
use roaring::RoaringBitmap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use self::error::{EmbedError, NewEmbedderError};
|
||||
use crate::prompt::{Prompt, PromptData};
|
||||
use crate::ThreadPoolNoAbort;
|
||||
|
||||
pub mod error;
|
||||
pub mod hf;
|
||||
pub mod json_template;
|
||||
pub mod manual;
|
||||
pub mod openai;
|
||||
pub mod parsed_vectors;
|
||||
pub mod settings;
|
||||
|
||||
pub mod ollama;
|
||||
pub mod rest;
|
||||
|
||||
pub use self::error::Error;
|
||||
|
||||
pub type Embedding = Vec<f32>;
|
||||
|
||||
pub const REQUEST_PARALLELISM: usize = 40;
|
||||
|
||||
pub struct ArroyWrapper {
|
||||
quantized: bool,
|
||||
index: u16,
|
||||
database: arroy::Database<Unspecified>,
|
||||
}
|
||||
|
||||
impl ArroyWrapper {
|
||||
pub fn new(database: arroy::Database<Unspecified>, index: u16, quantized: bool) -> Self {
|
||||
Self { database, index, quantized }
|
||||
}
|
||||
|
||||
pub fn index(&self) -> u16 {
|
||||
self.index
|
||||
}
|
||||
|
||||
pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
|
||||
if self.quantized {
|
||||
Ok(arroy::Reader::open(rtxn, self.index, self.quantized_db())?.dimensions())
|
||||
} else {
|
||||
Ok(arroy::Reader::open(rtxn, self.index, self.angular_db())?.dimensions())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn quantize(
|
||||
&mut self,
|
||||
wtxn: &mut RwTxn,
|
||||
index: u16,
|
||||
dimension: usize,
|
||||
) -> Result<(), arroy::Error> {
|
||||
if !self.quantized {
|
||||
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
|
||||
writer.prepare_changing_distance::<BinaryQuantizedAngular>(wtxn)?;
|
||||
self.quantized = true;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn need_build(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension).need_build(rtxn)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension).need_build(rtxn)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build<R: rand::Rng + rand::SeedableRng>(
|
||||
&self,
|
||||
wtxn: &mut RwTxn,
|
||||
rng: &mut R,
|
||||
dimension: usize,
|
||||
) -> Result<(), arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension).build(wtxn, rng, None)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension).build(wtxn, rng, None)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_item(
|
||||
&self,
|
||||
wtxn: &mut RwTxn,
|
||||
dimension: usize,
|
||||
item_id: arroy::ItemId,
|
||||
vector: &[f32],
|
||||
) -> Result<(), arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension)
|
||||
.add_item(wtxn, item_id, vector)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension)
|
||||
.add_item(wtxn, item_id, vector)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn del_item(
|
||||
&self,
|
||||
wtxn: &mut RwTxn,
|
||||
dimension: usize,
|
||||
item_id: arroy::ItemId,
|
||||
) -> Result<bool, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension).del_item(wtxn, item_id)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension).del_item(wtxn, item_id)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension).clear(wtxn)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension).clear(wtxn)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension).is_empty(rtxn)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension).is_empty(rtxn)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains_item(
|
||||
&self,
|
||||
rtxn: &RoTxn,
|
||||
dimension: usize,
|
||||
item: arroy::ItemId,
|
||||
) -> Result<bool, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Writer::new(self.quantized_db(), self.index, dimension).contains_item(rtxn, item)
|
||||
} else {
|
||||
arroy::Writer::new(self.angular_db(), self.index, dimension).contains_item(rtxn, item)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nns_by_item(
|
||||
&self,
|
||||
rtxn: &RoTxn,
|
||||
item: ItemId,
|
||||
limit: usize,
|
||||
filter: Option<&RoaringBitmap>,
|
||||
) -> Result<Option<Vec<(ItemId, f32)>>, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Reader::open(rtxn, self.index, self.quantized_db())?
|
||||
.nns_by_item(rtxn, item, limit, None, None, filter)
|
||||
} else {
|
||||
arroy::Reader::open(rtxn, self.index, self.angular_db())?
|
||||
.nns_by_item(rtxn, item, limit, None, None, filter)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nns_by_vector(
|
||||
&self,
|
||||
txn: &RoTxn,
|
||||
item: &[f32],
|
||||
limit: usize,
|
||||
filter: Option<&RoaringBitmap>,
|
||||
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Reader::open(txn, self.index, self.quantized_db())?
|
||||
.nns_by_vector(txn, item, limit, None, None, filter)
|
||||
} else {
|
||||
arroy::Reader::open(txn, self.index, self.angular_db())?
|
||||
.nns_by_vector(txn, item, limit, None, None, filter)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn item_vector(&self, rtxn: &RoTxn, docid: u32) -> Result<Option<Vec<f32>>, arroy::Error> {
|
||||
if self.quantized {
|
||||
arroy::Reader::open(rtxn, self.index, self.quantized_db())?.item_vector(rtxn, docid)
|
||||
} else {
|
||||
arroy::Reader::open(rtxn, self.index, self.angular_db())?.item_vector(rtxn, docid)
|
||||
}
|
||||
}
|
||||
|
||||
/// Views the underlying database with its data type remapped for the `Angular` distance.
fn angular_db(&self) -> arroy::Database<Angular> {
    let remapped: arroy::Database<Angular> = self.database.remap_data_type();
    remapped
}
|
||||
|
||||
/// Views the underlying database with its data type remapped for the
/// `BinaryQuantizedAngular` distance.
fn quantized_db(&self) -> arroy::Database<BinaryQuantizedAngular> {
    let remapped: arroy::Database<BinaryQuantizedAngular> = self.database.remap_data_type();
    remapped
}
|
||||
}
|
||||
|
||||
/// One or multiple embeddings stored consecutively in a flat vector.
///
/// Invariant: `data.len()` is always a multiple of `dimension`. When `dimension` is `0`
/// the flat vector is always empty and the container holds no embedding.
pub struct Embeddings<F> {
    data: Vec<F>,
    dimension: usize,
}

impl<F> Embeddings<F> {
    /// Declares an empty vector of embeddings of the specified dimensions.
    pub fn new(dimension: usize) -> Self {
        Self { data: Vec::new(), dimension }
    }

    /// Declares a vector of embeddings containing a single element.
    ///
    /// The dimension is inferred from the length of the passed embedding.
    pub fn from_single_embedding(embedding: Vec<F>) -> Self {
        Self { dimension: embedding.len(), data: embedding }
    }

    /// Declares a vector of embeddings from its components.
    ///
    /// # Errors
    ///
    /// `data.len()` must be a multiple of `dimension`, otherwise `data` is handed back
    /// as the error. With a `dimension` of `0`, only an empty `data` is accepted.
    pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
        let mut this = Self::new(dimension);
        this.append(data)?;
        Ok(this)
    }

    /// Returns the number of embeddings in this vector of embeddings.
    ///
    /// Returns `0` when the dimension is `0` instead of dividing by zero.
    pub fn embedding_count(&self) -> usize {
        self.data.len().checked_div(self.dimension).unwrap_or(0)
    }

    /// Dimension of a single embedding.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Deconstructs self into the inner flat vector.
    pub fn into_inner(self) -> Vec<F> {
        self.data
    }

    /// A reference to the inner flat vector.
    pub fn as_inner(&self) -> &[F] {
        &self.data
    }

    /// Iterates over the embeddings contained in the flat vector.
    ///
    /// Yields nothing when the dimension is `0`.
    pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
        // `chunks_exact` panics on a chunk size of 0. A dimension of 0 implies that `data`
        // is empty, so chunking by 1 yields the expected empty iterator.
        self.data.as_slice().chunks_exact(self.dimension.max(1))
    }

    /// Push an embedding at the end of the embeddings.
    ///
    /// # Errors
    ///
    /// If `embedding.len() != self.dimension`, then the push operation fails and the
    /// embedding is handed back as the error.
    pub fn push(&mut self, embedding: Vec<F>) -> Result<(), Vec<F>> {
        if embedding.len() != self.dimension {
            return Err(embedding);
        }
        self.data.extend(embedding);
        Ok(())
    }

    /// Append a flat vector of embeddings at the end of the embeddings.
    ///
    /// # Errors
    ///
    /// If `embeddings.len()` is not a multiple of `self.dimension`, then the append
    /// operation fails and the flat vector is handed back as the error. With a dimension
    /// of `0`, only an empty vector can be appended.
    pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
        let is_whole = match self.dimension {
            // Avoids a remainder by zero: nothing but "no data" fits a 0-dimension layout.
            0 => embeddings.is_empty(),
            dimension => embeddings.len() % dimension == 0,
        };
        if !is_whole {
            return Err(embeddings);
        }
        self.data.append(&mut embeddings);
        Ok(())
    }
}
|
||||
|
||||
/// An embedder can be used to transform text into embeddings.
|
||||
#[derive(Debug)]
|
||||
pub enum Embedder {
|
||||
/// An embedder based on running local models, fetched from the Hugging Face Hub.
|
||||
HuggingFace(hf::Embedder),
|
||||
/// An embedder based on making embedding queries against the OpenAI API.
|
||||
OpenAi(openai::Embedder),
|
||||
/// An embedder based on the user providing the embeddings in the documents and queries.
|
||||
UserProvided(manual::Embedder),
|
||||
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
|
||||
Ollama(ollama::Embedder),
|
||||
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
|
||||
Rest(rest::Embedder),
|
||||
}
|
||||
|
||||
/// Configuration for an embedder.
|
||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbeddingConfig {
|
||||
/// Options of the embedder, specific to each kind of embedder
|
||||
pub embedder_options: EmbedderOptions,
|
||||
/// Document template
|
||||
pub prompt: PromptData,
|
||||
/// If this embedder is binary quantized
|
||||
pub quantized: Option<bool>,
|
||||
// TODO: add metrics and anything needed
|
||||
}
|
||||
|
||||
impl EmbeddingConfig {
|
||||
pub fn quantized(&self) -> bool {
|
||||
self.quantized.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Map of embedder configurations.
|
||||
///
|
||||
/// Each configuration is mapped to a name.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);
|
||||
|
||||
impl EmbeddingConfigs {
|
||||
/// Create the map from its internal component.s
|
||||
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
|
||||
Self(data)
|
||||
}
|
||||
|
||||
/// Get an embedder configuration and template from its name.
|
||||
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
|
||||
self.0.get(name).cloned()
|
||||
}
|
||||
|
||||
pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for EmbeddingConfigs {
|
||||
type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));
|
||||
|
||||
type IntoIter =
|
||||
std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// Options of an embedder, specific to each kind of embedder.
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub enum EmbedderOptions {
|
||||
HuggingFace(hf::EmbedderOptions),
|
||||
OpenAi(openai::EmbedderOptions),
|
||||
Ollama(ollama::EmbedderOptions),
|
||||
UserProvided(manual::EmbedderOptions),
|
||||
Rest(rest::EmbedderOptions),
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
fn default() -> Self {
|
||||
Self::HuggingFace(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
/// Spawns a new embedder built from its options.
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
Ok(match options {
|
||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
|
||||
EmbedderOptions::UserProvided(options) => {
|
||||
Self::UserProvided(manual::Embedder::new(options))
|
||||
}
|
||||
EmbedderOptions::Rest(options) => {
|
||||
Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Embed one or multiple texts.
|
||||
///
|
||||
/// Each text can be embedded as one or multiple embeddings.
|
||||
pub fn embed(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
||||
Embedder::Ollama(embedder) => embedder.embed(texts),
|
||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||
Embedder::Rest(embedder) => embedder.embed(texts),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
|
||||
let mut embeddings = self.embed(vec![text])?;
|
||||
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
|
||||
Ok(if embeddings.iter().nth(1).is_some() {
|
||||
tracing::warn!("Ignoring embeddings past the first one in long search query");
|
||||
embeddings.iter().next().unwrap().to_vec()
|
||||
} else {
|
||||
embeddings.into_inner()
|
||||
})
|
||||
}
|
||||
|
||||
/// Embed multiple chunks of texts.
|
||||
///
|
||||
/// Each chunk is composed of one or multiple texts.
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &ThreadPoolNoAbort,
|
||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::UserProvided(_) => 1,
|
||||
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::UserProvided(_) => 1,
|
||||
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the dimensions of a single embedding produced by the embedder.
|
||||
pub fn dimensions(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.dimensions(),
|
||||
Embedder::OpenAi(embedder) => embedder.dimensions(),
|
||||
Embedder::Ollama(embedder) => embedder.dimensions(),
|
||||
Embedder::UserProvided(embedder) => embedder.dimensions(),
|
||||
Embedder::Rest(embedder) => embedder.dimensions(),
|
||||
}
|
||||
}
|
||||
|
||||
/// An optional distribution used to apply an affine transformation to the similarity score of a document.
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.distribution(),
|
||||
Embedder::OpenAi(embedder) => embedder.distribution(),
|
||||
Embedder::Ollama(embedder) => embedder.distribution(),
|
||||
Embedder::UserProvided(embedder) => embedder.distribution(),
|
||||
Embedder::Rest(embedder) => embedder.distribution(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn uses_document_template(&self) -> bool {
|
||||
match self {
|
||||
Embedder::HuggingFace(_)
|
||||
| Embedder::OpenAi(_)
|
||||
| Embedder::Ollama(_)
|
||||
| Embedder::Rest(_) => true,
|
||||
Embedder::UserProvided(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
|
||||
///
|
||||
/// The intended use is to make the similarity score more comparable to the regular ranking score.
|
||||
/// This allows to correct effects where results are too "packed" around a certain value.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
|
||||
#[serde(from = "DistributionShiftSerializable")]
|
||||
#[serde(into = "DistributionShiftSerializable")]
|
||||
pub struct DistributionShift {
|
||||
/// Value where the results are "packed".
|
||||
///
|
||||
/// Similarity scores are translated so that they are packed around 0.5 instead
|
||||
pub current_mean: OrderedFloat<f32>,
|
||||
|
||||
/// standard deviation of a similarity score.
|
||||
///
|
||||
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
|
||||
pub current_sigma: OrderedFloat<f32>,
|
||||
}
|
||||
|
||||
impl<E> Deserr<E> for DistributionShift
|
||||
where
|
||||
E: DeserializeError,
|
||||
{
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef<'_>,
|
||||
) -> Result<Self, E> {
|
||||
let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
|
||||
if value.mean < 0. || value.mean > 1. {
|
||||
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
|
||||
None,
|
||||
deserr::ErrorKind::Unexpected {
|
||||
msg: format!(
|
||||
"the distribution mean must be in the range [0, 1], got {}",
|
||||
value.mean
|
||||
),
|
||||
},
|
||||
location,
|
||||
)));
|
||||
}
|
||||
if value.sigma <= 0. || value.sigma > 1. {
|
||||
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
|
||||
None,
|
||||
deserr::ErrorKind::Unexpected {
|
||||
msg: format!(
|
||||
"the distribution sigma must be in the range ]0, 1], got {}",
|
||||
value.sigma
|
||||
),
|
||||
},
|
||||
location,
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(value.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Deserr)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
#[deserr(deny_unknown_fields)]
|
||||
struct DistributionShiftSerializable {
|
||||
mean: f32,
|
||||
sigma: f32,
|
||||
}
|
||||
|
||||
/// Unwraps the ordered floats into the plain serialized representation.
impl From<DistributionShift> for DistributionShiftSerializable {
    fn from(shift: DistributionShift) -> Self {
        let mean = shift.current_mean.0;
        let sigma = shift.current_sigma.0;
        Self { mean, sigma }
    }
}
|
||||
|
||||
impl From<DistributionShiftSerializable> for DistributionShift {
|
||||
fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
|
||||
Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
|
||||
}
|
||||
}
|
||||
|
||||
impl DistributionShift {
|
||||
/// `None` if sigma <= 0.
|
||||
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
|
||||
if sigma <= 0.0 {
|
||||
None
|
||||
} else {
|
||||
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shift(&self, score: f32) -> f32 {
|
||||
let current_mean = self.current_mean.0;
|
||||
let current_sigma = self.current_sigma.0;
|
||||
// <https://math.stackexchange.com/a/2894689>
|
||||
// We're somewhat abusively mapping the distribution of distances to a gaussian.
|
||||
// The parameters we're given is the mean and sigma of the native result distribution.
|
||||
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
|
||||
|
||||
let target_mean = 0.5;
|
||||
let target_sigma = 0.4;
|
||||
|
||||
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
|
||||
let factor = target_sigma / current_sigma;
|
||||
// a*mu1 + b = mu2 => b = mu2 - a*mu1
|
||||
let offset = target_mean - (factor * current_mean);
|
||||
|
||||
let mut score = factor * score + offset;
|
||||
|
||||
// clamp the final score in the ]0, 1] interval.
|
||||
if score <= 0.0 {
|
||||
score = f32::EPSILON;
|
||||
}
|
||||
if score > 1.0 {
|
||||
score = 1.0;
|
||||
}
|
||||
|
||||
score
|
||||
}
|
||||
}
|
||||
|
||||
/// Tells whether this build of Meilisearch was compiled with the `cuda` feature.
pub const fn is_cuda_enabled() -> bool {
    const CUDA_ENABLED: bool = cfg!(feature = "cuda");
    CUDA_ENABLED
}
|
||||
|
||||
/// Enumerates the 256 arroy index ids reserved for `embedder_id`.
///
/// The embedder id forms the high byte of every returned id, and the low byte
/// runs from `0` to `u8::MAX` included.
pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
    let high_byte = u16::from(embedder_id) << 8;

    (0..=u8::MAX).map(move |low_byte| high_byte | u16::from(low_byte))
}
|
123
crates/milli/src/vector/ollama.rs
Normal file
123
crates/milli/src/vector/ollama.rs
Normal file
|
@ -0,0 +1,123 @@
|
|||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||
|
||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||
use super::{DistributionShift, Embeddings};
|
||||
use crate::error::FaultSource;
|
||||
use crate::ThreadPoolNoAbort;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
rest_embedder: RestEmbedder,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub embedding_model: String,
|
||||
pub url: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
pub dimensions: Option<usize>,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn with_default_model(
|
||||
api_key: Option<String>,
|
||||
url: Option<String>,
|
||||
dimensions: Option<usize>,
|
||||
) -> Self {
|
||||
Self {
|
||||
embedding_model: "nomic-embed-text".into(),
|
||||
api_key,
|
||||
url,
|
||||
distribution: None,
|
||||
dimensions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let model = options.embedding_model.as_str();
|
||||
let rest_embedder = match RestEmbedder::new(
|
||||
RestEmbedderOptions {
|
||||
api_key: options.api_key,
|
||||
dimensions: options.dimensions,
|
||||
distribution: options.distribution,
|
||||
url: options.url.unwrap_or_else(get_ollama_path),
|
||||
request: serde_json::json!({
|
||||
"model": model,
|
||||
"prompt": super::rest::REQUEST_PLACEHOLDER,
|
||||
}),
|
||||
response: serde_json::json!({
|
||||
"embedding": super::rest::RESPONSE_PLACEHOLDER,
|
||||
}),
|
||||
headers: Default::default(),
|
||||
},
|
||||
super::rest::ConfigurationSource::Ollama,
|
||||
) {
|
||||
Ok(embedder) => embedder,
|
||||
Err(NewEmbedderError {
|
||||
kind:
|
||||
NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
|
||||
kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
|
||||
fault: _,
|
||||
}),
|
||||
fault: _,
|
||||
}) => {
|
||||
return Err(NewEmbedderError::could_not_determine_dimension(
|
||||
EmbedError::ollama_model_not_found(error),
|
||||
))
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
|
||||
Ok(Self { rest_embedder })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self.rest_embedder.embed(texts) {
|
||||
Ok(embeddings) => Ok(embeddings),
|
||||
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
||||
Err(EmbedError::ollama_model_not_found(error))
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &ThreadPoolNoAbort,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
threads
|
||||
.install(move || {
|
||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||
})
|
||||
.map_err(|error| EmbedError {
|
||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||
fault: FaultSource::Bug,
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
self.rest_embedder.chunk_count_hint()
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
self.rest_embedder.prompt_count_in_chunk_hint()
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.rest_embedder.dimensions()
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.rest_embedder.distribution()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the URL of the Ollama embeddings endpoint.
///
/// The value of the `MEILI_OLLAMA_URL` environment variable is used when it is set to
/// valid unicode, otherwise the local default endpoint is returned.
fn get_ollama_path() -> String {
    // Important: Hostname not enough, has to be entire path to embeddings endpoint
    std::env::var("MEILI_OLLAMA_URL")
        // Lazily build the fallback so nothing is allocated when the variable is set.
        .unwrap_or_else(|_| "http://localhost:11434/api/embeddings".to_owned())
}
|
274
crates/milli/src/vector/openai.rs
Normal file
274
crates/milli/src/vector/openai.rs
Normal file
|
@ -0,0 +1,274 @@
|
|||
use ordered_float::OrderedFloat;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||
|
||||
use super::error::{EmbedError, NewEmbedderError};
|
||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||
use super::{DistributionShift, Embeddings};
|
||||
use crate::error::FaultSource;
|
||||
use crate::vector::error::EmbedErrorKind;
|
||||
use crate::ThreadPoolNoAbort;
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub url: Option<String>,
|
||||
pub api_key: Option<String>,
|
||||
pub embedding_model: EmbeddingModel,
|
||||
pub dimensions: Option<usize>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn dimensions(&self) -> usize {
|
||||
if self.embedding_model.supports_overriding_dimensions() {
|
||||
self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
|
||||
} else {
|
||||
self.embedding_model.default_dimensions()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request(&self) -> serde_json::Value {
|
||||
let model = self.embedding_model.name();
|
||||
|
||||
let mut request = serde_json::json!({
|
||||
"model": model,
|
||||
"input": [super::rest::REQUEST_PLACEHOLDER, super::rest::REPEAT_PLACEHOLDER]
|
||||
});
|
||||
|
||||
if self.embedding_model.supports_overriding_dimensions() {
|
||||
if let Some(dimensions) = self.dimensions {
|
||||
request["dimensions"] = dimensions.into();
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.distribution.or(self.embedding_model.distribution())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
Copy,
|
||||
Default,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
serde::Serialize,
|
||||
serde::Deserialize,
|
||||
deserr::Deserr,
|
||||
)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum EmbeddingModel {
|
||||
// # WARNING
|
||||
//
|
||||
// If ever adding a model, make sure to add it to the list of supported models below.
|
||||
#[serde(rename = "text-embedding-ada-002")]
|
||||
#[deserr(rename = "text-embedding-ada-002")]
|
||||
TextEmbeddingAda002,
|
||||
|
||||
#[default]
|
||||
#[serde(rename = "text-embedding-3-small")]
|
||||
#[deserr(rename = "text-embedding-3-small")]
|
||||
TextEmbedding3Small,
|
||||
|
||||
#[serde(rename = "text-embedding-3-large")]
|
||||
#[deserr(rename = "text-embedding-3-large")]
|
||||
TextEmbedding3Large,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn supported_models() -> &'static [&'static str] {
|
||||
&["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
|
||||
}
|
||||
|
||||
pub fn max_token(&self) -> usize {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => 8191,
|
||||
EmbeddingModel::TextEmbedding3Large => 8191,
|
||||
EmbeddingModel::TextEmbedding3Small => 8191,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_dimensions(&self) -> usize {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => 1536,
|
||||
EmbeddingModel::TextEmbedding3Large => 3072,
|
||||
EmbeddingModel::TextEmbedding3Small => 1536,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
|
||||
EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
|
||||
EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_name(name: &str) -> Option<Self> {
|
||||
match name {
|
||||
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
|
||||
"text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
|
||||
"text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn distribution(&self) -> Option<DistributionShift> {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
|
||||
current_mean: OrderedFloat(0.90),
|
||||
current_sigma: OrderedFloat(0.08),
|
||||
}),
|
||||
EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
|
||||
current_mean: OrderedFloat(0.70),
|
||||
current_sigma: OrderedFloat(0.1),
|
||||
}),
|
||||
EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
|
||||
current_mean: OrderedFloat(0.75),
|
||||
current_sigma: OrderedFloat(0.1),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supports_overriding_dimensions(&self) -> bool {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => false,
|
||||
EmbeddingModel::TextEmbedding3Large => true,
|
||||
EmbeddingModel::TextEmbedding3Small => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Embeddings endpoint used by the OpenAI embedder when its options carry no URL.
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn with_default_model(api_key: Option<String>) -> Self {
|
||||
Self {
|
||||
api_key,
|
||||
embedding_model: Default::default(),
|
||||
dimensions: None,
|
||||
distribution: None,
|
||||
url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads the API key from the environment.
///
/// `MEILI_OPENAI_API_KEY` takes precedence over `OPENAI_API_KEY`; an empty string is
/// returned when neither variable can be read.
fn infer_api_key() -> String {
    for variable in ["MEILI_OPENAI_API_KEY", "OPENAI_API_KEY"] {
        if let Ok(key) = std::env::var(variable) {
            return key;
        }
    }
    String::new()
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
rest_embedder: RestEmbedder,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let mut inferred_api_key = Default::default();
|
||||
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
||||
inferred_api_key = infer_api_key();
|
||||
&inferred_api_key
|
||||
});
|
||||
|
||||
let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned();
|
||||
|
||||
let rest_embedder = RestEmbedder::new(
|
||||
RestEmbedderOptions {
|
||||
api_key: (!api_key.is_empty()).then(|| api_key.clone()),
|
||||
distribution: None,
|
||||
dimensions: Some(options.dimensions()),
|
||||
url,
|
||||
request: options.request(),
|
||||
response: serde_json::json!({
|
||||
"data": [{
|
||||
"embedding": super::rest::RESPONSE_PLACEHOLDER
|
||||
},
|
||||
super::rest::REPEAT_PLACEHOLDER
|
||||
]
|
||||
}),
|
||||
headers: Default::default(),
|
||||
},
|
||||
super::rest::ConfigurationSource::OpenAi,
|
||||
)?;
|
||||
|
||||
// looking at the code it is very unclear that this can actually fail.
|
||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||
|
||||
Ok(Self { options, rest_embedder, tokenizer })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self.rest_embedder.embed_ref(&texts) {
|
||||
Ok(embeddings) => Ok(embeddings),
|
||||
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
|
||||
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
||||
self.try_embed_tokenized(&texts)
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||
for text in text {
|
||||
let max_token_count = self.options.embedding_model.max_token();
|
||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||
let len = encoded.len();
|
||||
if len < max_token_count {
|
||||
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
|
||||
continue;
|
||||
}
|
||||
|
||||
let tokens = &encoded.as_slice()[0..max_token_count];
|
||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
||||
|
||||
let embedding = self.rest_embedder.embed_tokens(tokens)?;
|
||||
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
|
||||
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
|
||||
})?;
|
||||
|
||||
all_embeddings.push(embeddings_for_prompt);
|
||||
}
|
||||
Ok(all_embeddings)
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &ThreadPoolNoAbort,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
threads
|
||||
.install(move || {
|
||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||
})
|
||||
.map_err(|error| EmbedError {
|
||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||
fault: FaultSource::Bug,
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
self.rest_embedder.chunk_count_hint()
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
self.rest_embedder.prompt_count_in_chunk_hint()
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.options.dimensions()
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.options.distribution()
|
||||
}
|
||||
}
|
417
crates/milli/src/vector/parsed_vectors.rs
Normal file
417
crates/milli/src/vector/parsed_vectors.rs
Normal file
|
@ -0,0 +1,417 @@
|
|||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use deserr::{take_cf_content, DeserializeError, Deserr, Sequence};
|
||||
use obkv::KvReader;
|
||||
use serde_json::{from_slice, Value};
|
||||
|
||||
use super::Embedding;
|
||||
use crate::index::IndexEmbeddingConfig;
|
||||
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
|
||||
use crate::{DocumentId, FieldId, InternalError, UserError};
|
||||
|
||||
/// Name of the reserved `_vectors` document field.
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
||||
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum Vectors {
|
||||
ImplicitlyUserProvided(VectorOrArrayOfVectors),
|
||||
Explicit(ExplicitVectors),
|
||||
}
|
||||
|
||||
impl<E: DeserializeError> Deserr<E> for Vectors {
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef<'_>,
|
||||
) -> Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::Sequence(_) | deserr::Value::Null => {
|
||||
Ok(Vectors::ImplicitlyUserProvided(VectorOrArrayOfVectors::deserialize_from_value(
|
||||
value, location,
|
||||
)?))
|
||||
}
|
||||
deserr::Value::Map(_) => {
|
||||
Ok(Vectors::Explicit(ExplicitVectors::deserialize_from_value(value, location)?))
|
||||
}
|
||||
|
||||
value => Err(take_cf_content(E::error(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[
|
||||
deserr::ValueKind::Sequence,
|
||||
deserr::ValueKind::Map,
|
||||
deserr::ValueKind::Null,
|
||||
],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Vectors {
|
||||
pub fn must_regenerate(&self) -> bool {
|
||||
match self {
|
||||
Vectors::ImplicitlyUserProvided(_) => false,
|
||||
Vectors::Explicit(ExplicitVectors { regenerate, .. }) => *regenerate,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
|
||||
match self {
|
||||
Vectors::ImplicitlyUserProvided(embeddings) => {
|
||||
Some(embeddings.into_array_of_vectors().unwrap_or_default())
|
||||
}
|
||||
Vectors::Explicit(ExplicitVectors { embeddings, regenerate: _ }) => {
|
||||
embeddings.map(|embeddings| embeddings.into_array_of_vectors().unwrap_or_default())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Deserr, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ExplicitVectors {
|
||||
#[serde(default)]
|
||||
#[deserr(default)]
|
||||
pub embeddings: Option<VectorOrArrayOfVectors>,
|
||||
pub regenerate: bool,
|
||||
}
|
||||
|
||||
pub enum VectorState {
|
||||
Inline(Vectors),
|
||||
Manual,
|
||||
Generated,
|
||||
}
|
||||
|
||||
impl VectorState {
|
||||
pub fn must_regenerate(&self) -> bool {
|
||||
match self {
|
||||
VectorState::Inline(vectors) => vectors.must_regenerate(),
|
||||
VectorState::Manual => false,
|
||||
VectorState::Generated => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum VectorsState {
|
||||
NoVectorsFid,
|
||||
NoVectorsFieldInDocument,
|
||||
Vectors(BTreeMap<String, Vectors>),
|
||||
}
|
||||
|
||||
pub struct ParsedVectorsDiff {
|
||||
old: BTreeMap<String, VectorState>,
|
||||
new: VectorsState,
|
||||
}
|
||||
|
||||
impl ParsedVectorsDiff {
|
||||
pub fn new(
|
||||
docid: DocumentId,
|
||||
embedders_configs: &[IndexEmbeddingConfig],
|
||||
documents_diff: KvReader<'_, FieldId>,
|
||||
old_vectors_fid: Option<FieldId>,
|
||||
new_vectors_fid: Option<FieldId>,
|
||||
) -> Result<Self, Error> {
|
||||
let mut old = match old_vectors_fid
|
||||
.and_then(|vectors_fid| documents_diff.get(vectors_fid))
|
||||
.map(KvReaderDelAdd::new)
|
||||
.map(|obkv| to_vector_map(obkv, DelAdd::Deletion))
|
||||
.transpose()
|
||||
{
|
||||
Ok(del) => del,
|
||||
// ignore wrong shape for old version of documents, use an empty map in this case
|
||||
Err(Error::InvalidMap(value)) => {
|
||||
tracing::warn!(%value, "Previous version of the `_vectors` field had a wrong shape");
|
||||
Default::default()
|
||||
}
|
||||
Err(error) => {
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
.flatten().map_or(BTreeMap::default(), |del| del.into_iter().map(|(name, vec)| (name, VectorState::Inline(vec))).collect());
|
||||
for embedding_config in embedders_configs {
|
||||
if embedding_config.user_provided.contains(docid) {
|
||||
old.entry(embedding_config.name.to_string()).or_insert(VectorState::Manual);
|
||||
}
|
||||
}
|
||||
|
||||
let new = 'new: {
|
||||
let Some(new_vectors_fid) = new_vectors_fid else {
|
||||
break 'new VectorsState::NoVectorsFid;
|
||||
};
|
||||
let Some(bytes) = documents_diff.get(new_vectors_fid) else {
|
||||
break 'new VectorsState::NoVectorsFieldInDocument;
|
||||
};
|
||||
let obkv = KvReaderDelAdd::new(bytes);
|
||||
match to_vector_map(obkv, DelAdd::Addition)? {
|
||||
Some(new) => VectorsState::Vectors(new),
|
||||
None => VectorsState::NoVectorsFieldInDocument,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self { old, new })
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, embedder_name: &str) -> (VectorState, VectorState) {
|
||||
let old = self.old.remove(embedder_name).unwrap_or(VectorState::Generated);
|
||||
let state_from_old = match old {
|
||||
// assume a userProvided is still userProvided
|
||||
VectorState::Manual => VectorState::Manual,
|
||||
// generated is still generated
|
||||
VectorState::Generated => VectorState::Generated,
|
||||
// weird case that shouldn't happen were the previous docs version is inline,
|
||||
// but it was removed in the new version
|
||||
// Since it is not in the new version, we switch to generated
|
||||
VectorState::Inline(_) => VectorState::Generated,
|
||||
};
|
||||
let new = match &mut self.new {
|
||||
VectorsState::Vectors(new) => {
|
||||
new.remove(embedder_name).map(VectorState::Inline).unwrap_or(state_from_old)
|
||||
}
|
||||
_ =>
|
||||
// if no `_vectors` field is present in the new document,
|
||||
// the state depends on the previous version of the document
|
||||
{
|
||||
state_from_old
|
||||
}
|
||||
};
|
||||
|
||||
(old, new)
|
||||
}
|
||||
|
||||
pub fn into_new_vectors_keys_iter(self) -> impl Iterator<Item = String> {
|
||||
let maybe_it = match self.new {
|
||||
VectorsState::NoVectorsFid => None,
|
||||
VectorsState::NoVectorsFieldInDocument => None,
|
||||
VectorsState::Vectors(vectors) => Some(vectors.into_keys()),
|
||||
};
|
||||
maybe_it.into_iter().flatten()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parsed content of a document's vectors field: one [`Vectors`] entry per embedder name.
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
|
||||
|
||||
impl<E: DeserializeError> Deserr<E> for ParsedVectors {
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef<'_>,
|
||||
) -> Result<Self, E> {
|
||||
let value = <BTreeMap<String, Vectors>>::deserialize_from_value(value, location)?;
|
||||
Ok(ParsedVectors(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl ParsedVectors {
|
||||
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
|
||||
let value: serde_json::Value = from_slice(value).map_err(Error::InternalSerdeJson)?;
|
||||
deserr::deserialize(value).map_err(|error| Error::InvalidEmbedderConf { error })
|
||||
}
|
||||
|
||||
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
|
||||
self.0.retain(|k, _v| !embedders.contains(k))
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Error {
|
||||
InvalidMap(Value),
|
||||
InvalidEmbedderConf { error: deserr::errors::JsonError },
|
||||
InternalSerdeJson(serde_json::Error),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn to_crate_error(self, document_id: String) -> crate::Error {
|
||||
match self {
|
||||
Error::InvalidMap(value) => {
|
||||
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
|
||||
}
|
||||
Error::InvalidEmbedderConf { error } => {
|
||||
crate::Error::UserError(UserError::InvalidVectorsEmbedderConf {
|
||||
document_id,
|
||||
error,
|
||||
})
|
||||
}
|
||||
Error::InternalSerdeJson(error) => {
|
||||
crate::Error::InternalError(InternalError::SerdeJson(error))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_vector_map(
|
||||
obkv: KvReaderDelAdd<'_>,
|
||||
side: DelAdd,
|
||||
) -> Result<Option<BTreeMap<String, Vectors>>, Error> {
|
||||
Ok(if let Some(value) = obkv.get(side) {
|
||||
let ParsedVectors(parsed_vectors) = ParsedVectors::from_bytes(value)?;
|
||||
Some(parsed_vectors)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
/// Represents either a vector or an array of multiple vectors.
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
#[serde(transparent)]
|
||||
pub struct VectorOrArrayOfVectors {
|
||||
#[serde(with = "either::serde_untagged_optional")]
|
||||
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
|
||||
}
|
||||
|
||||
impl<E: DeserializeError> Deserr<E> for VectorOrArrayOfVectors {
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef<'_>,
|
||||
) -> Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::Null => Ok(VectorOrArrayOfVectors { inner: None }),
|
||||
deserr::Value::Sequence(seq) => {
|
||||
let mut iter = seq.into_iter();
|
||||
match iter.next().map(|v| v.into_value()) {
|
||||
None => {
|
||||
// With the strange way serde serialize the `Either`, we must send the left part
|
||||
// otherwise it'll consider we returned [[]]
|
||||
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(Vec::new())) })
|
||||
}
|
||||
Some(val @ deserr::Value::Sequence(_)) => {
|
||||
let first = Embedding::deserialize_from_value(val, location.push_index(0))?;
|
||||
let mut collect = vec![first];
|
||||
let mut tail = iter
|
||||
.enumerate()
|
||||
.map(|(i, v)| {
|
||||
Embedding::deserialize_from_value(
|
||||
v.into_value(),
|
||||
location.push_index(i + 1),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
collect.append(&mut tail);
|
||||
|
||||
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(collect)) })
|
||||
}
|
||||
Some(
|
||||
val @ deserr::Value::Integer(_)
|
||||
| val @ deserr::Value::NegativeInteger(_)
|
||||
| val @ deserr::Value::Float(_),
|
||||
) => {
|
||||
let first = <f32>::deserialize_from_value(val, location.push_index(0))?;
|
||||
let mut embedding = iter
|
||||
.enumerate()
|
||||
.map(|(i, v)| {
|
||||
<f32>::deserialize_from_value(
|
||||
v.into_value(),
|
||||
location.push_index(i + 1),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
embedding.insert(0, first);
|
||||
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Right(embedding)) })
|
||||
}
|
||||
Some(value) => Err(take_cf_content(E::error(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Float],
|
||||
},
|
||||
location.push_index(0),
|
||||
))),
|
||||
}
|
||||
}
|
||||
value => Err(take_cf_content(E::error(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Null],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VectorOrArrayOfVectors {
|
||||
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
|
||||
match self.inner? {
|
||||
either::Either::Left(vectors) => Some(vectors),
|
||||
either::Either::Right(vector) => Some(vec![vector]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_array_of_vectors(array_of_vec: Vec<Embedding>) -> Self {
|
||||
Self { inner: Some(either::Either::Left(array_of_vec)) }
|
||||
}
|
||||
|
||||
pub fn from_vector(vec: Embedding) -> Self {
|
||||
Self { inner: Some(either::Either::Right(vec)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Embedding> for VectorOrArrayOfVectors {
|
||||
fn from(vec: Embedding) -> Self {
|
||||
Self::from_vector(vec)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<Embedding>> for VectorOrArrayOfVectors {
|
||||
fn from(vec: Vec<Embedding>) -> Self {
|
||||
Self::from_array_of_vectors(vec)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use super::VectorOrArrayOfVectors;

    /// Deserializes a JSON string into a `VectorOrArrayOfVectors`; panics on invalid JSON.
    fn embedding_from_str(s: &str) -> Result<VectorOrArrayOfVectors, deserr::errors::JsonError> {
        let value: serde_json::Value = serde_json::from_str(s).unwrap();
        deserr::deserialize(value)
    }

    /// Every accepted input shape is normalized to an array of vectors (or `null`).
    #[test]
    fn array_of_vectors() {
        let null = embedding_from_str("null").unwrap();
        let empty = embedding_from_str("[]").unwrap();
        let one = embedding_from_str("[0.1]").unwrap();
        let two = embedding_from_str("[0.1, 0.2]").unwrap();
        let one_vec = embedding_from_str("[[0.1, 0.2]]").unwrap();
        let two_vecs = embedding_from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();

        insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
        insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");
        insta::assert_json_snapshot!(one.into_array_of_vectors(), @r###"
        [
          [
            0.1
          ]
        ]
        "###);
        insta::assert_json_snapshot!(two.into_array_of_vectors(), @r###"
        [
          [
            0.1,
            0.2
          ]
        ]
        "###);
        insta::assert_json_snapshot!(one_vec.into_array_of_vectors(), @r###"
        [
          [
            0.1,
            0.2
          ]
        ]
        "###);
        insta::assert_json_snapshot!(two_vecs.into_array_of_vectors(), @r###"
        [
          [
            0.1,
            0.2
          ],
          [
            0.3,
            0.4
          ]
        ]
        "###);
    }
}
|
411
crates/milli/src/vector/rest.rs
Normal file
411
crates/milli/src/vector/rest.rs
Normal file
|
@ -0,0 +1,411 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use deserr::Deserr;
|
||||
use rand::Rng;
|
||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::error::EmbedErrorKind;
|
||||
use super::json_template::ValueTemplate;
|
||||
use super::{
|
||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||
};
|
||||
use crate::error::FaultSource;
|
||||
use crate::ThreadPoolNoAbort;
|
||||
|
||||
// retrying in case of failure
|
||||
pub struct Retry {
|
||||
pub error: EmbedError,
|
||||
strategy: RetryStrategy,
|
||||
}
|
||||
|
||||
/// Origin of the REST embedder configuration.
///
/// The value is forwarded to the constructors of the "unauthorized" and "bad request" errors.
// NOTE(review): presumably `OpenAi`/`Ollama` mark embedders that reuse the REST embedder
// internally, and `User` marks a REST embedder configured directly — confirm against the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigurationSource {
    OpenAi,
    Ollama,
    User,
}
|
||||
|
||||
/// How a failed request should be handled.
pub enum RetryStrategy {
    /// Do not retry: the error is returned to the caller.
    GiveUp,
    /// Retry after a delay growing exponentially with the attempt number.
    Retry,
    /// Retry after 1ms.
    // NOTE(review): the name suggests the input should be tokenized before retrying — confirm
    // with the callers of `Retry::must_tokenize`.
    RetryTokenized,
    /// Retry after 100ms plus a delay growing exponentially with the attempt number.
    RetryAfterRateLimit,
}
|
||||
|
||||
impl Retry {
|
||||
pub fn give_up(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::GiveUp }
|
||||
}
|
||||
|
||||
pub fn retry_later(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::Retry }
|
||||
}
|
||||
|
||||
pub fn retry_tokenized(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryTokenized }
|
||||
}
|
||||
|
||||
pub fn rate_limited(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
||||
}
|
||||
|
||||
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
|
||||
match self.strategy {
|
||||
RetryStrategy::GiveUp => Err(self.error),
|
||||
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
|
||||
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
|
||||
RetryStrategy::RetryAfterRateLimit => {
|
||||
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn must_tokenize(&self) -> bool {
|
||||
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
||||
}
|
||||
|
||||
pub fn into_error(self) -> EmbedError {
|
||||
self.error
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
data: EmbedderData,
|
||||
dimensions: usize,
|
||||
distribution: Option<DistributionShift>,
|
||||
}
|
||||
|
||||
/// All data needed to perform requests and parse responses
|
||||
#[derive(Debug)]
|
||||
struct EmbedderData {
|
||||
client: ureq::Agent,
|
||||
bearer: Option<String>,
|
||||
headers: BTreeMap<String, String>,
|
||||
url: String,
|
||||
request: Request,
|
||||
response: Response,
|
||||
configuration_source: ConfigurationSource,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub api_key: Option<String>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
pub dimensions: Option<usize>,
|
||||
pub url: String,
|
||||
pub request: serde_json::Value,
|
||||
pub response: serde_json::Value,
|
||||
pub headers: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
impl std::hash::Hash for EmbedderOptions {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.api_key.hash(state);
|
||||
self.distribution.hash(state);
|
||||
self.dimensions.hash(state);
|
||||
self.url.hash(state);
|
||||
// skip hashing the request and response
|
||||
// collisions in regular usage should be minimal,
|
||||
// and the list is limited to 256 values anyway
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
enum InputType {
|
||||
Text,
|
||||
TextArray,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(
|
||||
options: EmbedderOptions,
|
||||
configuration_source: ConfigurationSource,
|
||||
) -> Result<Self, NewEmbedderError> {
|
||||
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
|
||||
|
||||
let client = ureq::AgentBuilder::new()
|
||||
.max_idle_connections(REQUEST_PARALLELISM * 2)
|
||||
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
|
||||
.build();
|
||||
|
||||
let request = Request::new(options.request)?;
|
||||
let response = Response::new(options.response, &request)?;
|
||||
|
||||
let data = EmbedderData {
|
||||
client,
|
||||
bearer,
|
||||
url: options.url,
|
||||
request,
|
||||
response,
|
||||
configuration_source,
|
||||
headers: options.headers,
|
||||
};
|
||||
|
||||
let dimensions = if let Some(dimensions) = options.dimensions {
|
||||
dimensions
|
||||
} else {
|
||||
infer_dimensions(&data)?
|
||||
};
|
||||
|
||||
Ok(Self { data, dimensions, distribution: options.distribution })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
|
||||
}
|
||||
|
||||
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||
where
|
||||
S: AsRef<str> + Serialize,
|
||||
{
|
||||
embed(&self.data, texts, texts.len(), Some(self.dimensions))
|
||||
}
|
||||
|
||||
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
||||
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
|
||||
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
||||
Ok(embeddings.pop().unwrap())
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &ThreadPoolNoAbort,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
threads
|
||||
.install(move || {
|
||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||
})
|
||||
.map_err(|error| EmbedError {
|
||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||
fault: FaultSource::Bug,
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
super::REQUEST_PARALLELISM
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
match self.data.request.input_type() {
|
||||
InputType::Text => 1,
|
||||
InputType::TextArray => 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.distribution
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
||||
let v = embed(data, ["test"].as_slice(), 1, None)
|
||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||
Ok(v.first().unwrap().dimension())
|
||||
}
|
||||
|
||||
fn embed<S>(
|
||||
data: &EmbedderData,
|
||||
inputs: &[S],
|
||||
expected_count: usize,
|
||||
expected_dimension: Option<usize>,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||
where
|
||||
S: Serialize,
|
||||
{
|
||||
let request = data.client.post(&data.url);
|
||||
let request = if let Some(bearer) = &data.bearer {
|
||||
request.set("Authorization", bearer)
|
||||
} else {
|
||||
request
|
||||
};
|
||||
let mut request = request.set("Content-Type", "application/json");
|
||||
for (header, value) in &data.headers {
|
||||
request = request.set(header.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let body = data.request.inject_texts(inputs);
|
||||
|
||||
for attempt in 0..10 {
|
||||
let response = request.clone().send_json(&body);
|
||||
let result = check_response(response, data.configuration_source);
|
||||
|
||||
let retry_duration = match result {
|
||||
Ok(response) => {
|
||||
return response_to_embedding(response, data, expected_count, expected_dimension)
|
||||
}
|
||||
Err(retry) => {
|
||||
tracing::warn!("Failed: {}", retry.error);
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
|
||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||
|
||||
// randomly up to double the retry duration
|
||||
let retry_duration = retry_duration
|
||||
+ rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
|
||||
|
||||
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
|
||||
std::thread::sleep(retry_duration);
|
||||
}
|
||||
|
||||
let response = request.send_json(&body);
|
||||
let result = check_response(response, data.configuration_source);
|
||||
result.map_err(Retry::into_error).and_then(|response| {
|
||||
response_to_embedding(response, data, expected_count, expected_dimension)
|
||||
})
|
||||
}
|
||||
|
||||
fn check_response(
|
||||
response: Result<ureq::Response, ureq::Error>,
|
||||
configuration_source: ConfigurationSource,
|
||||
) -> Result<ureq::Response, Retry> {
|
||||
match response {
|
||||
Ok(response) => Ok(response),
|
||||
Err(ureq::Error::Status(code, response)) => {
|
||||
let error_response: Option<String> = response.into_string().ok();
|
||||
Err(match code {
|
||||
401 => Retry::give_up(EmbedError::rest_unauthorized(
|
||||
error_response,
|
||||
configuration_source,
|
||||
)),
|
||||
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
|
||||
400 => Retry::give_up(EmbedError::rest_bad_request(
|
||||
error_response,
|
||||
configuration_source,
|
||||
)),
|
||||
500..=599 => {
|
||||
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
||||
}
|
||||
402..=499 => {
|
||||
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
|
||||
}
|
||||
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
|
||||
})
|
||||
}
|
||||
Err(ureq::Error::Transport(transport)) => {
|
||||
Err(Retry::retry_later(EmbedError::rest_network(transport)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn response_to_embedding(
|
||||
response: ureq::Response,
|
||||
data: &EmbedderData,
|
||||
expected_count: usize,
|
||||
expected_dimensions: Option<usize>,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let response: serde_json::Value =
|
||||
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
||||
|
||||
let embeddings = data.response.extract_embeddings(response)?;
|
||||
|
||||
if embeddings.len() != expected_count {
|
||||
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
||||
}
|
||||
|
||||
if let Some(dimensions) = expected_dimensions {
|
||||
for embedding in &embeddings {
|
||||
if embedding.dimension() != dimensions {
|
||||
return Err(EmbedError::rest_unexpected_dimension(
|
||||
dimensions,
|
||||
embedding.dimension(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
|
||||
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
|
||||
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Request {
|
||||
template: ValueTemplate,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
|
||||
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||
Ok(template) => template,
|
||||
Err(error) => {
|
||||
let message =
|
||||
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self { template })
|
||||
}
|
||||
|
||||
fn input_type(&self) -> InputType {
|
||||
if self.template.has_array_value() {
|
||||
InputType::TextArray
|
||||
} else {
|
||||
InputType::Text
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inject_texts<S: Serialize>(
|
||||
&self,
|
||||
texts: impl IntoIterator<Item = S>,
|
||||
) -> serde_json::Value {
|
||||
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Response {
|
||||
template: ValueTemplate,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
|
||||
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
||||
{
|
||||
Ok(template) => template,
|
||||
Err(error) => {
|
||||
let message =
|
||||
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||
}
|
||||
};
|
||||
|
||||
match (template.has_array_value(), request.template.has_array_value()) {
|
||||
(true, true) | (false, false) => Ok(Self {template}),
|
||||
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
|
||||
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_embeddings(
|
||||
&self,
|
||||
response: serde_json::Value,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
||||
Ok(extracted_values) => extracted_values,
|
||||
Err(error) => {
|
||||
let error_message =
|
||||
error.error_message("response", "{{embedding}}", "an array of numbers");
|
||||
return Err(EmbedError::rest_extraction_error(error_message));
|
||||
}
|
||||
};
|
||||
let embeddings: Vec<Embeddings<f32>> =
|
||||
extracted_values.into_iter().map(Embeddings::from_single_embedding).collect();
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
770
crates/milli/src/vector/settings.rs
Normal file
770
crates/milli/src/vector/settings.rs
Normal file
|
@ -0,0 +1,770 @@
|
|||
use std::collections::BTreeMap;
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
use deserr::Deserr;
|
||||
use roaring::RoaringBitmap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{ollama, openai, DistributionShift};
|
||||
use crate::prompt::{default_max_bytes, PromptData};
|
||||
use crate::update::Setting;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
use crate::UserError;
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub struct EmbeddingSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub source: Setting<EmbedderSource>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub model: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub revision: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub dimensions: Setting<usize>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub binary_quantized: Setting<bool>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub document_template: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub document_template_max_bytes: Setting<usize>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub url: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub request: Setting<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub response: Setting<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub headers: Setting<BTreeMap<String, String>>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub distribution: Setting<DistributionShift>,
|
||||
}
|
||||
|
||||
pub fn check_unset<T>(
|
||||
key: &Setting<T>,
|
||||
field: &'static str,
|
||||
source: EmbedderSource,
|
||||
embedder_name: &str,
|
||||
) -> Result<(), UserError> {
|
||||
if matches!(key, Setting::NotSet) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(UserError::InvalidFieldForSource {
|
||||
embedder_name: embedder_name.to_owned(),
|
||||
source_: source,
|
||||
field,
|
||||
allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source),
|
||||
allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Action to perform for an embedder during a reindexing operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReindexAction {
    /// Index again for this embedder while keeping the existing vectors,
    /// checking whether the document template changed or not.
    RegeneratePrompts,
    /// Index all documents again for this embedder, removing the existing vectors
    /// (except userProvided ones).
    FullReindex,
}
|
||||
|
||||
pub enum SettingsDiff {
|
||||
Remove,
|
||||
Reindex { action: ReindexAction, updated_settings: EmbeddingSettings, quantize: bool },
|
||||
UpdateWithoutReindex { updated_settings: EmbeddingSettings, quantize: bool },
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct EmbedderAction {
|
||||
pub was_quantized: bool,
|
||||
pub is_being_quantized: bool,
|
||||
pub write_back: Option<WriteBackToDocuments>,
|
||||
pub reindex: Option<ReindexAction>,
|
||||
}
|
||||
|
||||
impl EmbedderAction {
|
||||
pub fn is_being_quantized(&self) -> bool {
|
||||
self.is_being_quantized
|
||||
}
|
||||
|
||||
pub fn write_back(&self) -> Option<&WriteBackToDocuments> {
|
||||
self.write_back.as_ref()
|
||||
}
|
||||
|
||||
pub fn reindex(&self) -> Option<&ReindexAction> {
|
||||
self.reindex.as_ref()
|
||||
}
|
||||
|
||||
pub fn with_is_being_quantized(mut self, quantize: bool) -> Self {
|
||||
self.is_being_quantized = quantize;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_write_back(write_back: WriteBackToDocuments, was_quantized: bool) -> Self {
|
||||
Self {
|
||||
was_quantized,
|
||||
is_being_quantized: false,
|
||||
write_back: Some(write_back),
|
||||
reindex: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_reindex(reindex: ReindexAction, was_quantized: bool) -> Self {
|
||||
Self { was_quantized, is_being_quantized: false, write_back: None, reindex: Some(reindex) }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WriteBackToDocuments {
|
||||
pub embedder_id: u8,
|
||||
pub user_provided: RoaringBitmap,
|
||||
}
|
||||
|
||||
impl SettingsDiff {
|
||||
pub fn from_settings(
|
||||
embedder_name: &str,
|
||||
old: EmbeddingSettings,
|
||||
new: Setting<EmbeddingSettings>,
|
||||
) -> Result<Self, UserError> {
|
||||
let ret = match new {
|
||||
Setting::Set(new) => {
|
||||
let EmbeddingSettings {
|
||||
mut source,
|
||||
mut model,
|
||||
mut revision,
|
||||
mut api_key,
|
||||
mut dimensions,
|
||||
mut document_template,
|
||||
mut url,
|
||||
mut request,
|
||||
mut response,
|
||||
mut distribution,
|
||||
mut headers,
|
||||
mut document_template_max_bytes,
|
||||
binary_quantized: mut binary_quantize,
|
||||
} = old;
|
||||
|
||||
let EmbeddingSettings {
|
||||
source: new_source,
|
||||
model: new_model,
|
||||
revision: new_revision,
|
||||
api_key: new_api_key,
|
||||
dimensions: new_dimensions,
|
||||
document_template: new_document_template,
|
||||
url: new_url,
|
||||
request: new_request,
|
||||
response: new_response,
|
||||
distribution: new_distribution,
|
||||
headers: new_headers,
|
||||
document_template_max_bytes: new_document_template_max_bytes,
|
||||
binary_quantized: new_binary_quantize,
|
||||
} = new;
|
||||
|
||||
if matches!(binary_quantize, Setting::Set(true))
|
||||
&& matches!(new_binary_quantize, Setting::Set(false))
|
||||
{
|
||||
return Err(UserError::InvalidDisableBinaryQuantization {
|
||||
embedder_name: embedder_name.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut reindex_action = None;
|
||||
|
||||
// **Warning**: do not use short-circuiting || here, we want all these operations applied
|
||||
if source.apply(new_source) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
// when the source changes, we need to reapply the default settings for the new source
|
||||
apply_default_for_source(
|
||||
&source,
|
||||
&mut model,
|
||||
&mut revision,
|
||||
&mut dimensions,
|
||||
&mut url,
|
||||
&mut request,
|
||||
&mut response,
|
||||
&mut document_template,
|
||||
&mut document_template_max_bytes,
|
||||
&mut headers,
|
||||
)
|
||||
}
|
||||
if model.apply(new_model) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
}
|
||||
if revision.apply(new_revision) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
}
|
||||
if dimensions.apply(new_dimensions) {
|
||||
match source {
|
||||
// regenerate on dimensions change in OpenAI since truncation is supported
|
||||
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
|
||||
ReindexAction::push_action(
|
||||
&mut reindex_action,
|
||||
ReindexAction::FullReindex,
|
||||
);
|
||||
}
|
||||
// for all other embedders, the parameter is a hint that should not be able to change the result
|
||||
// and so won't cause a reindex by itself.
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let binary_quantize_changed = binary_quantize.apply(new_binary_quantize);
|
||||
if url.apply(new_url) {
|
||||
match source {
|
||||
// do not regenerate on an url change in OpenAI
|
||||
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {}
|
||||
_ => {
|
||||
ReindexAction::push_action(
|
||||
&mut reindex_action,
|
||||
ReindexAction::FullReindex,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
if request.apply(new_request) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
}
|
||||
if response.apply(new_response) {
|
||||
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
|
||||
}
|
||||
if document_template.apply(new_document_template) {
|
||||
ReindexAction::push_action(
|
||||
&mut reindex_action,
|
||||
ReindexAction::RegeneratePrompts,
|
||||
);
|
||||
}
|
||||
|
||||
if document_template_max_bytes.apply(new_document_template_max_bytes) {
|
||||
let previous_document_template_max_bytes =
|
||||
document_template_max_bytes.set().unwrap_or(default_max_bytes().get());
|
||||
let new_document_template_max_bytes =
|
||||
new_document_template_max_bytes.set().unwrap_or(default_max_bytes().get());
|
||||
|
||||
// only reindex if the size increased. Reasoning:
|
||||
// - size decrease is a performance optimization, so we don't reindex and we keep the more accurate vectors
|
||||
// - size increase is an accuracy optimization, so we want to reindex
|
||||
if new_document_template_max_bytes > previous_document_template_max_bytes {
|
||||
ReindexAction::push_action(
|
||||
&mut reindex_action,
|
||||
ReindexAction::RegeneratePrompts,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
distribution.apply(new_distribution);
|
||||
api_key.apply(new_api_key);
|
||||
headers.apply(new_headers);
|
||||
|
||||
let updated_settings = EmbeddingSettings {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
request,
|
||||
response,
|
||||
distribution,
|
||||
headers,
|
||||
document_template_max_bytes,
|
||||
binary_quantized: binary_quantize,
|
||||
};
|
||||
|
||||
match reindex_action {
|
||||
Some(action) => Self::Reindex {
|
||||
action,
|
||||
updated_settings,
|
||||
quantize: binary_quantize_changed,
|
||||
},
|
||||
None => Self::UpdateWithoutReindex {
|
||||
updated_settings,
|
||||
quantize: binary_quantize_changed,
|
||||
},
|
||||
}
|
||||
}
|
||||
Setting::Reset => Self::Remove,
|
||||
Setting::NotSet => {
|
||||
Self::UpdateWithoutReindex { updated_settings: old, quantize: false }
|
||||
}
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl ReindexAction {
|
||||
fn push_action(this: &mut Option<Self>, other: Self) {
|
||||
*this = match (*this, other) {
|
||||
(_, ReindexAction::FullReindex) => Some(ReindexAction::FullReindex),
|
||||
(Some(ReindexAction::FullReindex), _) => Some(ReindexAction::FullReindex),
|
||||
(_, ReindexAction::RegeneratePrompts) => Some(ReindexAction::RegeneratePrompts),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)] // private function
|
||||
fn apply_default_for_source(
|
||||
source: &Setting<EmbedderSource>,
|
||||
model: &mut Setting<String>,
|
||||
revision: &mut Setting<String>,
|
||||
dimensions: &mut Setting<usize>,
|
||||
url: &mut Setting<String>,
|
||||
request: &mut Setting<serde_json::Value>,
|
||||
response: &mut Setting<serde_json::Value>,
|
||||
document_template: &mut Setting<String>,
|
||||
document_template_max_bytes: &mut Setting<usize>,
|
||||
headers: &mut Setting<BTreeMap<String, String>>,
|
||||
) {
|
||||
match source {
|
||||
Setting::Set(EmbedderSource::HuggingFace) => {
|
||||
*model = Setting::Reset;
|
||||
*revision = Setting::Reset;
|
||||
*dimensions = Setting::NotSet;
|
||||
*url = Setting::NotSet;
|
||||
*request = Setting::NotSet;
|
||||
*response = Setting::NotSet;
|
||||
*headers = Setting::NotSet;
|
||||
}
|
||||
Setting::Set(EmbedderSource::Ollama) => {
|
||||
*model = Setting::Reset;
|
||||
*revision = Setting::NotSet;
|
||||
*dimensions = Setting::Reset;
|
||||
*url = Setting::NotSet;
|
||||
*request = Setting::NotSet;
|
||||
*response = Setting::NotSet;
|
||||
*headers = Setting::NotSet;
|
||||
}
|
||||
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
|
||||
*model = Setting::Reset;
|
||||
*revision = Setting::NotSet;
|
||||
*dimensions = Setting::NotSet;
|
||||
*url = Setting::Reset;
|
||||
*request = Setting::NotSet;
|
||||
*response = Setting::NotSet;
|
||||
*headers = Setting::NotSet;
|
||||
}
|
||||
Setting::Set(EmbedderSource::Rest) => {
|
||||
*model = Setting::NotSet;
|
||||
*revision = Setting::NotSet;
|
||||
*dimensions = Setting::Reset;
|
||||
*url = Setting::Reset;
|
||||
*request = Setting::Reset;
|
||||
*response = Setting::Reset;
|
||||
*headers = Setting::Reset;
|
||||
}
|
||||
Setting::Set(EmbedderSource::UserProvided) => {
|
||||
*model = Setting::NotSet;
|
||||
*revision = Setting::NotSet;
|
||||
*dimensions = Setting::Reset;
|
||||
*url = Setting::NotSet;
|
||||
*request = Setting::NotSet;
|
||||
*response = Setting::NotSet;
|
||||
*document_template = Setting::NotSet;
|
||||
*document_template_max_bytes = Setting::NotSet;
|
||||
*headers = Setting::NotSet;
|
||||
}
|
||||
Setting::NotSet => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_set<T>(
|
||||
key: &Setting<T>,
|
||||
field: &'static str,
|
||||
source: EmbedderSource,
|
||||
embedder_name: &str,
|
||||
) -> Result<(), UserError> {
|
||||
if matches!(key, Setting::Set(_)) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(UserError::MissingFieldForSource {
|
||||
field,
|
||||
source_: source,
|
||||
embedder_name: embedder_name.to_owned(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingSettings {
|
||||
pub const SOURCE: &'static str = "source";
|
||||
pub const MODEL: &'static str = "model";
|
||||
pub const REVISION: &'static str = "revision";
|
||||
pub const API_KEY: &'static str = "apiKey";
|
||||
pub const DIMENSIONS: &'static str = "dimensions";
|
||||
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
|
||||
pub const DOCUMENT_TEMPLATE_MAX_BYTES: &'static str = "documentTemplateMaxBytes";
|
||||
|
||||
pub const URL: &'static str = "url";
|
||||
pub const REQUEST: &'static str = "request";
|
||||
pub const RESPONSE: &'static str = "response";
|
||||
pub const HEADERS: &'static str = "headers";
|
||||
|
||||
pub const DISTRIBUTION: &'static str = "distribution";
|
||||
|
||||
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
|
||||
match field {
|
||||
Self::SOURCE => &[
|
||||
EmbedderSource::HuggingFace,
|
||||
EmbedderSource::OpenAi,
|
||||
EmbedderSource::UserProvided,
|
||||
EmbedderSource::Rest,
|
||||
EmbedderSource::Ollama,
|
||||
],
|
||||
Self::MODEL => {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
||||
}
|
||||
Self::REVISION => &[EmbedderSource::HuggingFace],
|
||||
Self::API_KEY => {
|
||||
&[EmbedderSource::OpenAi, EmbedderSource::Ollama, EmbedderSource::Rest]
|
||||
}
|
||||
Self::DIMENSIONS => &[
|
||||
EmbedderSource::OpenAi,
|
||||
EmbedderSource::UserProvided,
|
||||
EmbedderSource::Ollama,
|
||||
EmbedderSource::Rest,
|
||||
],
|
||||
Self::DOCUMENT_TEMPLATE => &[
|
||||
EmbedderSource::HuggingFace,
|
||||
EmbedderSource::OpenAi,
|
||||
EmbedderSource::Ollama,
|
||||
EmbedderSource::Rest,
|
||||
],
|
||||
Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi],
|
||||
Self::REQUEST => &[EmbedderSource::Rest],
|
||||
Self::RESPONSE => &[EmbedderSource::Rest],
|
||||
Self::HEADERS => &[EmbedderSource::Rest],
|
||||
Self::DISTRIBUTION => &[
|
||||
EmbedderSource::HuggingFace,
|
||||
EmbedderSource::Ollama,
|
||||
EmbedderSource::OpenAi,
|
||||
EmbedderSource::Rest,
|
||||
EmbedderSource::UserProvided,
|
||||
],
|
||||
_other => unreachable!("unknown field"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
|
||||
match source {
|
||||
EmbedderSource::OpenAi => &[
|
||||
Self::SOURCE,
|
||||
Self::MODEL,
|
||||
Self::API_KEY,
|
||||
Self::DOCUMENT_TEMPLATE,
|
||||
Self::DIMENSIONS,
|
||||
Self::DISTRIBUTION,
|
||||
Self::URL,
|
||||
],
|
||||
EmbedderSource::HuggingFace => &[
|
||||
Self::SOURCE,
|
||||
Self::MODEL,
|
||||
Self::REVISION,
|
||||
Self::DOCUMENT_TEMPLATE,
|
||||
Self::DISTRIBUTION,
|
||||
],
|
||||
EmbedderSource::Ollama => &[
|
||||
Self::SOURCE,
|
||||
Self::MODEL,
|
||||
Self::DOCUMENT_TEMPLATE,
|
||||
Self::URL,
|
||||
Self::API_KEY,
|
||||
Self::DIMENSIONS,
|
||||
Self::DISTRIBUTION,
|
||||
],
|
||||
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION],
|
||||
EmbedderSource::Rest => &[
|
||||
Self::SOURCE,
|
||||
Self::API_KEY,
|
||||
Self::DIMENSIONS,
|
||||
Self::DOCUMENT_TEMPLATE,
|
||||
Self::URL,
|
||||
Self::REQUEST,
|
||||
Self::RESPONSE,
|
||||
Self::HEADERS,
|
||||
Self::DISTRIBUTION,
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_default_source(setting: &mut Setting<EmbeddingSettings>) {
|
||||
if let Setting::Set(EmbeddingSettings {
|
||||
source: source @ (Setting::NotSet | Setting::Reset),
|
||||
..
|
||||
}) = setting
|
||||
{
|
||||
*source = Setting::Set(EmbedderSource::default())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_default_openai_model(setting: &mut Setting<EmbeddingSettings>) {
|
||||
if let Setting::Set(EmbeddingSettings {
|
||||
source: Setting::Set(EmbedderSource::OpenAi),
|
||||
model: model @ (Setting::NotSet | Setting::Reset),
|
||||
..
|
||||
}) = setting
|
||||
{
|
||||
*model = Setting::Set(openai::EmbeddingModel::default().name().to_owned())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The kind of embedder selected by the `source` field of the embedding settings.
///
/// Variant names are (de)serialized in camelCase, as requested by the `serde` and `deserr`
/// attributes below.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbedderSource {
    /// OpenAI embedder; this is the default source.
    #[default]
    OpenAi,
    /// Hugging Face embedder.
    HuggingFace,
    /// Ollama embedder.
    Ollama,
    /// NOTE(review): presumably the embeddings are supplied by the user rather than computed —
    /// confirm.
    UserProvided,
    /// REST embedder.
    Rest,
}
|
||||
|
||||
impl std::fmt::Display for EmbedderSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
EmbedderSource::OpenAi => "openAi",
|
||||
EmbedderSource::HuggingFace => "huggingFace",
|
||||
EmbedderSource::UserProvided => "userProvided",
|
||||
EmbedderSource::Ollama => "ollama",
|
||||
EmbedderSource::Rest => "rest",
|
||||
};
|
||||
f.write_str(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
fn from(value: EmbeddingConfig) -> Self {
|
||||
let EmbeddingConfig { embedder_options, prompt, quantized } = value;
|
||||
let document_template_max_bytes =
|
||||
Setting::Set(prompt.max_bytes.unwrap_or(default_max_bytes()).get());
|
||||
match embedder_options {
|
||||
super::EmbedderOptions::HuggingFace(super::hf::EmbedderOptions {
|
||||
model,
|
||||
revision,
|
||||
distribution,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::HuggingFace),
|
||||
model: Setting::Set(model),
|
||||
revision: Setting::some_or_not_set(revision),
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::NotSet,
|
||||
document_template: Setting::Set(prompt.template),
|
||||
document_template_max_bytes,
|
||||
url: Setting::NotSet,
|
||||
request: Setting::NotSet,
|
||||
response: Setting::NotSet,
|
||||
headers: Setting::NotSet,
|
||||
distribution: Setting::some_or_not_set(distribution),
|
||||
binary_quantized: Setting::some_or_not_set(quantized),
|
||||
},
|
||||
super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions {
|
||||
url,
|
||||
api_key,
|
||||
embedding_model,
|
||||
dimensions,
|
||||
distribution,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::OpenAi),
|
||||
model: Setting::Set(embedding_model.name().to_owned()),
|
||||
revision: Setting::NotSet,
|
||||
api_key: Setting::some_or_not_set(api_key),
|
||||
dimensions: Setting::some_or_not_set(dimensions),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
document_template_max_bytes,
|
||||
url: Setting::some_or_not_set(url),
|
||||
request: Setting::NotSet,
|
||||
response: Setting::NotSet,
|
||||
headers: Setting::NotSet,
|
||||
distribution: Setting::some_or_not_set(distribution),
|
||||
binary_quantized: Setting::some_or_not_set(quantized),
|
||||
},
|
||||
super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions {
|
||||
embedding_model,
|
||||
url,
|
||||
api_key,
|
||||
distribution,
|
||||
dimensions,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::Ollama),
|
||||
model: Setting::Set(embedding_model),
|
||||
revision: Setting::NotSet,
|
||||
api_key: Setting::some_or_not_set(api_key),
|
||||
dimensions: Setting::some_or_not_set(dimensions),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
document_template_max_bytes,
|
||||
url: Setting::some_or_not_set(url),
|
||||
request: Setting::NotSet,
|
||||
response: Setting::NotSet,
|
||||
headers: Setting::NotSet,
|
||||
distribution: Setting::some_or_not_set(distribution),
|
||||
binary_quantized: Setting::some_or_not_set(quantized),
|
||||
},
|
||||
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
|
||||
dimensions,
|
||||
distribution,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::UserProvided),
|
||||
model: Setting::NotSet,
|
||||
revision: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::Set(dimensions),
|
||||
document_template: Setting::NotSet,
|
||||
document_template_max_bytes: Setting::NotSet,
|
||||
url: Setting::NotSet,
|
||||
request: Setting::NotSet,
|
||||
response: Setting::NotSet,
|
||||
headers: Setting::NotSet,
|
||||
distribution: Setting::some_or_not_set(distribution),
|
||||
binary_quantized: Setting::some_or_not_set(quantized),
|
||||
},
|
||||
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||
api_key,
|
||||
dimensions,
|
||||
url,
|
||||
request,
|
||||
response,
|
||||
distribution,
|
||||
headers,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::Rest),
|
||||
model: Setting::NotSet,
|
||||
revision: Setting::NotSet,
|
||||
api_key: Setting::some_or_not_set(api_key),
|
||||
dimensions: Setting::some_or_not_set(dimensions),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
document_template_max_bytes,
|
||||
url: Setting::Set(url),
|
||||
request: Setting::Set(request),
|
||||
response: Setting::Set(response),
|
||||
distribution: Setting::some_or_not_set(distribution),
|
||||
headers: Setting::Set(headers),
|
||||
binary_quantized: Setting::some_or_not_set(quantized),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
fn from(value: EmbeddingSettings) -> Self {
|
||||
let mut this = Self::default();
|
||||
let EmbeddingSettings {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
document_template_max_bytes,
|
||||
url,
|
||||
request,
|
||||
response,
|
||||
distribution,
|
||||
headers,
|
||||
binary_quantized,
|
||||
} = value;
|
||||
|
||||
this.quantized = binary_quantized.set();
|
||||
|
||||
if let Some(source) = source.set() {
|
||||
match source {
|
||||
EmbedderSource::OpenAi => {
|
||||
let mut options = super::openai::EmbedderOptions::with_default_model(None);
|
||||
if let Some(model) = model.set() {
|
||||
if let Some(model) = super::openai::EmbeddingModel::from_name(&model) {
|
||||
options.embedding_model = model;
|
||||
}
|
||||
}
|
||||
if let Some(url) = url.set() {
|
||||
options.url = Some(url);
|
||||
}
|
||||
if let Some(api_key) = api_key.set() {
|
||||
options.api_key = Some(api_key);
|
||||
}
|
||||
if let Some(dimensions) = dimensions.set() {
|
||||
options.dimensions = Some(dimensions);
|
||||
}
|
||||
options.distribution = distribution.set();
|
||||
this.embedder_options = super::EmbedderOptions::OpenAi(options);
|
||||
}
|
||||
EmbedderSource::Ollama => {
|
||||
let mut options: ollama::EmbedderOptions =
|
||||
super::ollama::EmbedderOptions::with_default_model(
|
||||
api_key.set(),
|
||||
url.set(),
|
||||
dimensions.set(),
|
||||
);
|
||||
if let Some(model) = model.set() {
|
||||
options.embedding_model = model;
|
||||
}
|
||||
|
||||
options.distribution = distribution.set();
|
||||
this.embedder_options = super::EmbedderOptions::Ollama(options);
|
||||
}
|
||||
EmbedderSource::HuggingFace => {
|
||||
let mut options = super::hf::EmbedderOptions::default();
|
||||
if let Some(model) = model.set() {
|
||||
options.model = model;
|
||||
// Reset the revision if we are setting the model.
|
||||
// This allows the following:
|
||||
// "huggingFace": {} -> default model with default revision
|
||||
// "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision
|
||||
// "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision
|
||||
options.revision = None;
|
||||
}
|
||||
if let Some(revision) = revision.set() {
|
||||
options.revision = Some(revision);
|
||||
}
|
||||
options.distribution = distribution.set();
|
||||
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
|
||||
}
|
||||
EmbedderSource::UserProvided => {
|
||||
this.embedder_options =
|
||||
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
|
||||
dimensions: dimensions.set().unwrap(),
|
||||
distribution: distribution.set(),
|
||||
});
|
||||
}
|
||||
EmbedderSource::Rest => {
|
||||
this.embedder_options =
|
||||
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||
api_key: api_key.set(),
|
||||
dimensions: dimensions.set(),
|
||||
url: url.set().unwrap(),
|
||||
request: request.set().unwrap(),
|
||||
response: response.set().unwrap(),
|
||||
distribution: distribution.set(),
|
||||
headers: headers.set().unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Setting::Set(template) = document_template {
|
||||
let max_bytes = document_template_max_bytes
|
||||
.set()
|
||||
.and_then(NonZeroUsize::new)
|
||||
.unwrap_or(default_max_bytes());
|
||||
|
||||
this.prompt = PromptData { template, max_bytes: Some(max_bytes) }
|
||||
}
|
||||
|
||||
this
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue