Expose distribution in settings

This commit is contained in:
Louis Dureuil 2024-03-27 11:51:04 +01:00
parent 168ded3b9d
commit a25456120d
No known key found for this signature in database

View File

@ -2,7 +2,7 @@ use deserr::Deserr;
use serde::{Deserialize, Serialize};
use super::rest::InputType;
use super::{ollama, openai};
use super::{ollama, openai, DistributionShift};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
@ -48,6 +48,9 @@ pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub input_type: Setting<InputType>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub distribution: Setting<DistributionShift>,
}
pub fn check_unset<T>(
@ -101,6 +104,8 @@ impl EmbeddingSettings {
/// Field identifier for the embedding-object setting.
///
/// `allowed_sources_for_field` maps this identifier to the `Rest` source only.
// NOTE(review): the camelCase spelling presumably mirrors the key used in the
// settings payload — confirm against the (de)serialization attributes.
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
/// Field identifier for the input-type setting.
///
/// `allowed_sources_for_field` maps this identifier to the `Rest` source only.
pub const INPUT_TYPE: &'static str = "inputType";
/// Field identifier for the distribution setting.
///
/// `allowed_sources_for_field` maps this identifier to the `HuggingFace`,
/// `Ollama`, `OpenAi`, `Rest` and `UserProvided` sources.
pub const DISTRIBUTION: &'static str = "distribution";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field {
Self::SOURCE => &[
@ -132,6 +137,13 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
Self::INPUT_TYPE => &[EmbedderSource::Rest],
Self::DISTRIBUTION => &[
EmbedderSource::HuggingFace,
EmbedderSource::Ollama,
EmbedderSource::OpenAi,
EmbedderSource::Rest,
EmbedderSource::UserProvided,
],
_other => unreachable!("unknown field"),
}
}
@ -144,14 +156,24 @@ impl EmbeddingSettings {
Self::API_KEY,
Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS,
Self::DISTRIBUTION,
],
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::Ollama => {
&[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE, Self::URL, Self::API_KEY]
}
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
EmbedderSource::HuggingFace => &[
Self::SOURCE,
Self::MODEL,
Self::REVISION,
Self::DOCUMENT_TEMPLATE,
Self::DISTRIBUTION,
],
EmbedderSource::Ollama => &[
Self::SOURCE,
Self::MODEL,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::API_KEY,
Self::DISTRIBUTION,
],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION],
EmbedderSource::Rest => &[
Self::SOURCE,
Self::API_KEY,
@ -163,6 +185,7 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS,
Self::EMBEDDING_OBJECT,
Self::INPUT_TYPE,
Self::DISTRIBUTION,
],
}
}
@ -283,6 +306,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
@ -297,6 +321,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama),
@ -311,6 +336,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
@ -325,11 +351,10 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key,
// TODO: support distribution
distribution: _,
dimensions,
url,
query,
@ -337,6 +362,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings,
embedding_object,
input_type,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet,
@ -350,6 +376,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::Set(path_to_embeddings),
embedding_object: Setting::Set(embedding_object),
input_type: Setting::Set(input_type),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
},
}
}
@ -371,7 +398,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
path_to_embeddings,
embedding_object,
input_type,
distribution,
} = value;
if let Some(source) = source.set() {
match source {
EmbedderSource::OpenAi => {
@ -387,6 +416,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions);
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::Ollama => {
@ -399,6 +429,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
options.embedding_model = model;
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::Ollama(options);
}
EmbedderSource::HuggingFace => {
@ -415,12 +446,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(revision) = revision.set() {
options.revision = Some(revision);
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
}
EmbedderSource::UserProvided => {
this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(),
distribution: distribution.set(),
});
}
EmbedderSource::Rest => {
@ -429,7 +462,6 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(),
distribution: None,
dimensions: dimensions.set(),
url: url.set().unwrap(),
query: query.set().unwrap_or(embedder_options.query),
@ -441,6 +473,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
.set()
.unwrap_or(embedder_options.embedding_object),
input_type: input_type.set().unwrap_or(embedder_options.input_type),
distribution: distribution.set(),
})
}
}