diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 524f83b80..20013d8e8 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -17,6 +17,7 @@ pub struct Embedder { pub struct EmbedderOptions { pub api_key: Option, pub embedding_model: EmbeddingModel, + pub dimensions: Option, } #[derive( @@ -41,34 +42,54 @@ pub enum EmbeddingModel { #[serde(rename = "text-embedding-ada-002")] #[deserr(rename = "text-embedding-ada-002")] TextEmbeddingAda002, + + #[serde(rename = "text-embedding-3-small")] + #[deserr(rename = "text-embedding-3-small")] + TextEmbedding3Small, + + #[serde(rename = "text-embedding-3-large")] + #[deserr(rename = "text-embedding-3-large")] + TextEmbedding3Large, } impl EmbeddingModel { pub fn supported_models() -> &'static [&'static str] { - &["text-embedding-ada-002"] + &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"] } pub fn max_token(&self) -> usize { match self { EmbeddingModel::TextEmbeddingAda002 => 8191, + EmbeddingModel::TextEmbedding3Large => 8191, + EmbeddingModel::TextEmbedding3Small => 8191, } } pub fn dimensions(&self) -> usize { match self { EmbeddingModel::TextEmbeddingAda002 => 1536, + + //Default value for the model + EmbeddingModel::TextEmbedding3Large => 1536, + + //Default value for the model + EmbeddingModel::TextEmbedding3Small => 3072, } } pub fn name(&self) -> &'static str { match self { EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", + EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large", + EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small", } } pub fn from_name(name: &str) -> Option { match name { "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), + "text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large), + "text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small), _ => None, } } @@ -78,6 +99,20 @@ impl EmbeddingModel { EmbeddingModel::TextEmbeddingAda002 => { Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) } + EmbeddingModel::TextEmbedding3Large => { + Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) + } + EmbeddingModel::TextEmbedding3Small => { + Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) + } + } + } + + pub fn is_optional_dimensions_supported(&self) -> bool { + match self { + EmbeddingModel::TextEmbeddingAda002 => false, + EmbeddingModel::TextEmbedding3Large => true, + EmbeddingModel::TextEmbedding3Small => true, } } } @@ -86,11 +121,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; impl EmbedderOptions { pub fn with_default_model(api_key: Option) -> Self { - Self { api_key, embedding_model: Default::default() } + Self { api_key, embedding_model: Default::default(), dimensions: None } } pub fn with_embedding_model(api_key: Option, embedding_model: EmbeddingModel) -> Self { - Self { api_key, embedding_model } + Self { api_key, embedding_model, dimensions: None } } } @@ -237,7 +272,15 @@ impl Embedder { for text in texts { log::trace!("Received prompt: {}", text.as_ref()) } - let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; + let request = OpenAiRequest { + model: self.options.embedding_model.name(), + input: texts, + dimension: if self.options.embedding_model.is_optional_dimensions_supported() { + self.options.dimensions.as_ref() + } else { + None + }, + }; let response = client .post(OPENAI_EMBEDDINGS_URL) .json(&request) @@ -366,7 +409,7 @@ impl Embedder { } pub fn dimensions(&self) -> usize { - self.options.embedding_model.dimensions() + self.options.dimensions.unwrap_or_else(|| self.options.embedding_model.dimensions()) } pub fn distribution(&self) -> Option { @@ -431,6 +474,7 @@ impl Retry { struct OpenAiRequest<'a, S: AsRef + serde::Serialize> { model: &'a str, input: &'a [S], + dimension: Option<&'a usize>, } #[derive(Debug, Serialize)] diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 37fb80452..dac129ccd 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -208,6 +208,9 @@ impl From for EmbeddingConfig { if let Some(api_key) = api_key.set() { options.api_key = Some(api_key); } + if let Some(dimensions) = dimensions.set() { + options.dimensions = Some(dimensions); + } this.embedder_options = super::EmbedderOptions::OpenAi(options); } EmbedderSource::HuggingFace => {