feat: add new models and ability to override dimensions

This commit is contained in:
Gosti 2024-01-30 16:32:57 +01:00 committed by Louis Dureuil
parent 84235a63df
commit fb705116a6
No known key found for this signature in database
2 changed files with 52 additions and 5 deletions

View File

@ -17,6 +17,7 @@ pub struct Embedder {
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub api_key: Option<String>, pub api_key: Option<String>,
pub embedding_model: EmbeddingModel, pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>,
} }
#[derive( #[derive(
@ -41,34 +42,54 @@ pub enum EmbeddingModel {
#[serde(rename = "text-embedding-ada-002")] #[serde(rename = "text-embedding-ada-002")]
#[deserr(rename = "text-embedding-ada-002")] #[deserr(rename = "text-embedding-ada-002")]
TextEmbeddingAda002, TextEmbeddingAda002,
#[serde(rename = "text-embedding-3-small")]
#[deserr(rename = "text-embedding-3-small")]
TextEmbedding3Small,
#[serde(rename = "text-embedding-3-large")]
#[deserr(rename = "text-embedding-3-large")]
TextEmbedding3Large,
} }
impl EmbeddingModel { impl EmbeddingModel {
/// Returns the names of all supported embedding models.
pub fn supported_models() -> &'static [&'static str] {
    const SUPPORTED: &[&str] =
        &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"];
    SUPPORTED
}
pub fn max_token(&self) -> usize { pub fn max_token(&self) -> usize {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => 8191, EmbeddingModel::TextEmbeddingAda002 => 8191,
EmbeddingModel::TextEmbedding3Large => 8191,
EmbeddingModel::TextEmbedding3Small => 8191,
} }
} }
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => 1536, EmbeddingModel::TextEmbeddingAda002 => 1536,
//Default value for the model
EmbeddingModel::TextEmbedding3Large => 1536,
//Default value for the model
EmbeddingModel::TextEmbedding3Small => 3072,
} }
} }
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
} }
} }
pub fn from_name(name: &str) -> Option<Self> { pub fn from_name(name: &str) -> Option<Self> {
match name { match name {
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
"text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
"text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
_ => None, _ => None,
} }
} }
@ -78,6 +99,20 @@ impl EmbeddingModel {
EmbeddingModel::TextEmbeddingAda002 => { EmbeddingModel::TextEmbeddingAda002 => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
} }
EmbeddingModel::TextEmbedding3Large => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
}
EmbeddingModel::TextEmbedding3Small => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
}
}
}
pub fn is_optional_dimensions_supported(&self) -> bool {
match self {
EmbeddingModel::TextEmbeddingAda002 => false,
EmbeddingModel::TextEmbedding3Large => true,
EmbeddingModel::TextEmbedding3Small => true,
} }
} }
} }
@ -86,11 +121,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() } Self { api_key, embedding_model: Default::default(), dimensions: None }
} }
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model } Self { api_key, embedding_model, dimensions: None }
} }
} }
@ -237,7 +272,15 @@ impl Embedder {
for text in texts { for text in texts {
log::trace!("Received prompt: {}", text.as_ref()) log::trace!("Received prompt: {}", text.as_ref())
} }
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; let request = OpenAiRequest {
model: self.options.embedding_model.name(),
input: texts,
dimension: if self.options.embedding_model.is_optional_dimensions_supported() {
self.options.dimensions.as_ref()
} else {
None
},
};
let response = client let response = client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
.json(&request) .json(&request)
@ -366,7 +409,7 @@ impl Embedder {
} }
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions() self.options.dimensions.unwrap_or_else(|| self.options.embedding_model.dimensions())
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
@ -431,6 +474,7 @@ impl Retry {
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str, model: &'a str,
input: &'a [S], input: &'a [S],
dimension: Option<&'a usize>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]

View File

@ -208,6 +208,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(api_key) = api_key.set() { if let Some(api_key) = api_key.set() {
options.api_key = Some(api_key); options.api_key = Some(api_key);
} }
if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions);
}
this.embedder_options = super::EmbedderOptions::OpenAi(options); this.embedder_options = super::EmbedderOptions::OpenAi(options);
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {