4375: Feat: add new OpenAI models and ability to override dimensions r=dureuill a=Gosti

# Pull Request

Fixes #4394 

## Related discussion
https://github.com/orgs/meilisearch/discussions/677#discussioncomment-8306384

## What does this PR do?
- Add text-embedding-3-small
- Add text-embedding-3-large
- Add optional dimensions parameter for both new models


## Note
As the `dimensions` option is not available for text-embedding-ada-002, I've added a manual check to prevent its use with that model, but I feel it could be implemented in more idiomatic Rust.

## PR checklist
Please check if your PR fulfills the following requirements:
- [x] Does this PR fix an existing issue, or have you listed the changes applied in the PR description (and why they are needed)?
- [x] Have you read the contributing guidelines?
- [x] Have you made sure that the title is accurate and descriptive of the changes?

Thank you so much for contributing to Meilisearch!


Co-authored-by: Gosti <gostitsog@gmail.com>
Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2024-02-07 16:20:15 +00:00 committed by GitHub
commit 3e120619fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 142 additions and 20 deletions

View File

@ -347,6 +347,9 @@ impl ErrorCode for milli::Error {
UserError::InvalidFieldForSource { .. } UserError::InvalidFieldForSource { .. }
| UserError::MissingFieldForSource { .. } | UserError::MissingFieldForSource { .. }
| UserError::InvalidOpenAiModel { .. } | UserError::InvalidOpenAiModel { .. }
| UserError::InvalidOpenAiModelDimensions { .. }
| UserError::InvalidOpenAiModelDimensionsMax { .. }
| UserError::InvalidSettingsDimensions { .. }
| UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,

View File

@ -227,6 +227,22 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
source_: crate::vector::settings::EmbedderSource, source_: crate::vector::settings::EmbedderSource,
embedder_name: String, embedder_name: String,
}, },
#[error("`.embedders.{embedder_name}.dimensions`: Model `{model}` does not support overriding its native dimensions of {expected_dimensions}. Found {dimensions}")]
InvalidOpenAiModelDimensions {
embedder_name: String,
model: &'static str,
dimensions: usize,
expected_dimensions: usize,
},
#[error("`.embedders.{embedder_name}.dimensions`: Model `{model}` does not support overriding its dimensions to a value higher than {max_dimensions}. Found {dimensions}")]
InvalidOpenAiModelDimensionsMax {
embedder_name: String,
model: &'static str,
dimensions: usize,
max_dimensions: usize,
},
#[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")]
InvalidSettingsDimensions { embedder_name: String },
} }
impl From<crate::vector::Error> for Error { impl From<crate::vector::Error> for Error {

View File

@ -974,6 +974,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
crate::vector::settings::EmbeddingSettings::apply_default_source( crate::vector::settings::EmbeddingSettings::apply_default_source(
&mut setting, &mut setting,
); );
crate::vector::settings::EmbeddingSettings::apply_default_openai_model(
&mut setting,
);
let setting = validate_embedding_settings(setting, &name)?; let setting = validate_embedding_settings(setting, &name)?;
changed = true; changed = true;
new_configs.insert(name, setting); new_configs.insert(name, setting);
@ -1119,6 +1122,14 @@ pub fn validate_embedding_settings(
let Setting::Set(settings) = settings else { return Ok(settings) }; let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
settings; settings;
if let Some(0) = dimensions.set() {
return Err(crate::error::UserError::InvalidSettingsDimensions {
embedder_name: name.to_owned(),
}
.into());
}
let Some(inferred_source) = source.set() else { let Some(inferred_source) = source.set() else {
return Ok(Setting::Set(EmbeddingSettings { return Ok(Setting::Set(EmbeddingSettings {
source, source,
@ -1132,14 +1143,34 @@ pub fn validate_embedding_settings(
match inferred_source { match inferred_source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
check_unset(&revision, "revision", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
if let Setting::Set(model) = &model { if let Setting::Set(model) = &model {
crate::vector::openai::EmbeddingModel::from_name(model.as_str()).ok_or( let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
crate::error::UserError::InvalidOpenAiModel { .ok_or(crate::error::UserError::InvalidOpenAiModel {
embedder_name: name.to_owned(), embedder_name: name.to_owned(),
model: model.clone(), model: model.clone(),
}, })?;
)?; if let Setting::Set(dimensions) = dimensions {
if !model.supports_overriding_dimensions()
&& dimensions != model.default_dimensions()
{
return Err(crate::error::UserError::InvalidOpenAiModelDimensions {
embedder_name: name.to_owned(),
model: model.name(),
dimensions,
expected_dimensions: model.default_dimensions(),
}
.into());
}
if dimensions > model.default_dimensions() {
return Err(crate::error::UserError::InvalidOpenAiModelDimensionsMax {
embedder_name: name.to_owned(),
model: model.name(),
dimensions,
max_dimensions: model.default_dimensions(),
}
.into());
}
}
} }
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {

View File

@ -17,6 +17,7 @@ pub struct Embedder {
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub api_key: Option<String>, pub api_key: Option<String>,
pub embedding_model: EmbeddingModel, pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>,
} }
#[derive( #[derive(
@ -41,34 +42,50 @@ pub enum EmbeddingModel {
#[serde(rename = "text-embedding-ada-002")] #[serde(rename = "text-embedding-ada-002")]
#[deserr(rename = "text-embedding-ada-002")] #[deserr(rename = "text-embedding-ada-002")]
TextEmbeddingAda002, TextEmbeddingAda002,
#[serde(rename = "text-embedding-3-small")]
#[deserr(rename = "text-embedding-3-small")]
TextEmbedding3Small,
#[serde(rename = "text-embedding-3-large")]
#[deserr(rename = "text-embedding-3-large")]
TextEmbedding3Large,
} }
impl EmbeddingModel { impl EmbeddingModel {
pub fn supported_models() -> &'static [&'static str] { pub fn supported_models() -> &'static [&'static str] {
&["text-embedding-ada-002"] &["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
} }
pub fn max_token(&self) -> usize { pub fn max_token(&self) -> usize {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => 8191, EmbeddingModel::TextEmbeddingAda002 => 8191,
EmbeddingModel::TextEmbedding3Large => 8191,
EmbeddingModel::TextEmbedding3Small => 8191,
} }
} }
pub fn dimensions(&self) -> usize { pub fn default_dimensions(&self) -> usize {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => 1536, EmbeddingModel::TextEmbeddingAda002 => 1536,
EmbeddingModel::TextEmbedding3Large => 3072,
EmbeddingModel::TextEmbedding3Small => 1536,
} }
} }
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
} }
} }
pub fn from_name(name: &str) -> Option<Self> { pub fn from_name(name: &str) -> Option<Self> {
match name { match name {
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
"text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
"text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
_ => None, _ => None,
} }
} }
@ -78,6 +95,20 @@ impl EmbeddingModel {
EmbeddingModel::TextEmbeddingAda002 => { EmbeddingModel::TextEmbeddingAda002 => {
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
} }
EmbeddingModel::TextEmbedding3Large => {
Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 })
}
EmbeddingModel::TextEmbedding3Small => {
Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 })
}
}
}
pub fn supports_overriding_dimensions(&self) -> bool {
match self {
EmbeddingModel::TextEmbeddingAda002 => false,
EmbeddingModel::TextEmbedding3Large => true,
EmbeddingModel::TextEmbedding3Small => true,
} }
} }
} }
@ -86,11 +117,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() } Self { api_key, embedding_model: Default::default(), dimensions: None }
} }
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model } Self { api_key, embedding_model, dimensions: None }
} }
} }
@ -237,7 +268,11 @@ impl Embedder {
for text in texts { for text in texts {
log::trace!("Received prompt: {}", text.as_ref()) log::trace!("Received prompt: {}", text.as_ref())
} }
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; let request = OpenAiRequest {
model: self.options.embedding_model.name(),
input: texts,
dimensions: self.overriden_dimensions(),
};
let response = client let response = client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
.json(&request) .json(&request)
@ -280,8 +315,7 @@ impl Embedder {
} }
let mut tokens = encoded.as_slice(); let mut tokens = encoded.as_slice();
let mut embeddings_for_prompt = let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
Embeddings::new(self.options.embedding_model.dimensions());
while tokens.len() > max_token_count { while tokens.len() > max_token_count {
let window = &tokens[..max_token_count]; let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap(); embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
@ -322,8 +356,11 @@ impl Embedder {
tokens: &[usize], tokens: &[usize],
client: &reqwest::Client, client: &reqwest::Client,
) -> Result<Embedding, Retry> { ) -> Result<Embedding, Retry> {
let request = let request = OpenAiTokensRequest {
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; model: self.options.embedding_model.name(),
input: tokens,
dimensions: self.overriden_dimensions(),
};
let response = client let response = client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
.json(&request) .json(&request)
@ -366,12 +403,24 @@ impl Embedder {
} }
pub fn dimensions(&self) -> usize { pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions() if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
} else {
self.options.embedding_model.default_dimensions()
}
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution() self.options.embedding_model.distribution()
} }
/// Dimension override to attach to an embedding request, if any.
///
/// Returns the configured `dimensions` option only when the selected embedding model
/// supports overriding its dimensions; returns `None` for every other model, even if
/// a value is configured.
fn overriden_dimensions(&self) -> Option<usize> {
    let options = &self.options;
    // Guard first: a model without override support never gets a `dimensions` value.
    if !options.embedding_model.supports_overriding_dimensions() {
        return None;
    }
    options.dimensions
}
} }
// retrying in case of failure // retrying in case of failure
@ -431,12 +480,16 @@ impl Retry {
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> { struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str, model: &'a str,
input: &'a [S], input: &'a [S],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> { struct OpenAiTokensRequest<'a> {
model: &'a str, model: &'a str,
input: &'a [usize], input: &'a [usize],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]

View File

@ -1,6 +1,7 @@
use deserr::Deserr; use deserr::Deserr;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::openai;
use crate::prompt::PromptData; use crate::prompt::PromptData;
use crate::update::Setting; use crate::update::Setting;
use crate::vector::EmbeddingConfig; use crate::vector::EmbeddingConfig;
@ -82,7 +83,7 @@ impl EmbeddingSettings {
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::REVISION => &[EmbedderSource::HuggingFace], Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi], Self::API_KEY => &[EmbedderSource::OpenAi],
Self::DIMENSIONS => &[EmbedderSource::UserProvided], Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
_other => unreachable!("unknown field"), _other => unreachable!("unknown field"),
} }
@ -90,9 +91,13 @@ impl EmbeddingSettings {
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] { pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
match source { match source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => &[
&[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE] Self::SOURCE,
} Self::MODEL,
Self::API_KEY,
Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS,
],
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
} }
@ -109,6 +114,17 @@ impl EmbeddingSettings {
*source = Setting::Set(EmbedderSource::default()) *source = Setting::Set(EmbedderSource::default())
} }
} }
/// Fills in the default OpenAI model name when an OpenAI embedder has no model.
///
/// Acts only when `setting` is `Setting::Set`, its `source` is
/// `Setting::Set(EmbedderSource::OpenAi)`, and its `model` is `NotSet` or `Reset`.
/// In that case `model` becomes the name of `openai::EmbeddingModel::default()`.
/// Every other configuration is left untouched.
pub(crate) fn apply_default_openai_model(setting: &mut Setting<EmbeddingSettings>) {
    let Setting::Set(embedder) = setting else { return };

    let targets_openai = matches!(embedder.source, Setting::Set(EmbedderSource::OpenAi));
    let model_missing = matches!(embedder.model, Setting::NotSet | Setting::Reset);

    if targets_openai && model_missing {
        embedder.model = Setting::Set(openai::EmbeddingModel::default().name().to_owned());
    }
}
} }
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] #[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
@ -176,7 +192,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
model: Setting::Set(options.embedding_model.name().to_owned()), model: Setting::Set(options.embedding_model.name().to_owned()),
revision: Setting::NotSet, revision: Setting::NotSet,
api_key: options.api_key.map(Setting::Set).unwrap_or_default(), api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: Setting::NotSet, dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
}, },
super::EmbedderOptions::UserProvided(options) => Self { super::EmbedderOptions::UserProvided(options) => Self {
@ -208,6 +224,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(api_key) = api_key.set() { if let Some(api_key) = api_key.set() {
options.api_key = Some(api_key); options.api_key = Some(api_key);
} }
if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions);
}
this.embedder_options = super::EmbedderOptions::OpenAi(options); this.embedder_options = super::EmbedderOptions::OpenAi(options);
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {