mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-29 16:24:26 +01:00
feat: add new models and ability to override dimensions
This commit is contained in:
parent
84235a63df
commit
fb705116a6
@ -17,6 +17,7 @@ pub struct Embedder {
|
|||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
pub embedding_model: EmbeddingModel,
|
pub embedding_model: EmbeddingModel,
|
||||||
|
pub dimensions: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
@ -41,34 +42,54 @@ pub enum EmbeddingModel {
|
|||||||
#[serde(rename = "text-embedding-ada-002")]
|
#[serde(rename = "text-embedding-ada-002")]
|
||||||
#[deserr(rename = "text-embedding-ada-002")]
|
#[deserr(rename = "text-embedding-ada-002")]
|
||||||
TextEmbeddingAda002,
|
TextEmbeddingAda002,
|
||||||
|
|
||||||
|
#[serde(rename = "text-embedding-3-small")]
|
||||||
|
#[deserr(rename = "text-embedding-3-small")]
|
||||||
|
TextEmbedding3Small,
|
||||||
|
|
||||||
|
#[serde(rename = "text-embedding-3-large")]
|
||||||
|
#[deserr(rename = "text-embedding-3-large")]
|
||||||
|
TextEmbedding3Large,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbeddingModel {
|
impl EmbeddingModel {
|
||||||
pub fn supported_models() -> &'static [&'static str] {
|
pub fn supported_models() -> &'static [&'static str] {
|
||||||
&["text-embedding-ada-002"]
|
&["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn max_token(&self) -> usize {
|
pub fn max_token(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
EmbeddingModel::TextEmbeddingAda002 => 8191,
|
EmbeddingModel::TextEmbeddingAda002 => 8191,
|
||||||
|
EmbeddingModel::TextEmbedding3Large => 8191,
|
||||||
|
EmbeddingModel::TextEmbedding3Small => 8191,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dimensions(&self) -> usize {
|
pub fn dimensions(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
EmbeddingModel::TextEmbeddingAda002 => 1536,
|
EmbeddingModel::TextEmbeddingAda002 => 1536,
|
||||||
|
|
||||||
|
// Default output size of text-embedding-3-large (used when no override is configured).
|
||||||
|
EmbeddingModel::TextEmbedding3Large => 3072,
|
||||||
|
|
||||||
|
// Default output size of text-embedding-3-small (used when no override is configured).
|
||||||
|
EmbeddingModel::TextEmbedding3Small => 1536,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn name(&self) -> &'static str {
|
pub fn name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
|
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
|
||||||
|
EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
|
||||||
|
EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_name(name: &str) -> Option<Self> {
|
pub fn from_name(name: &str) -> Option<Self> {
|
||||||
match name {
|
match name {
|
||||||
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
|
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
|
||||||
|
"text-embedding-3-large" => Some(EmbeddingModel::TextEmbedding3Large),
|
||||||
|
"text-embedding-3-small" => Some(EmbeddingModel::TextEmbedding3Small),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,6 +99,20 @@ impl EmbeddingModel {
|
|||||||
EmbeddingModel::TextEmbeddingAda002 => {
|
EmbeddingModel::TextEmbeddingAda002 => {
|
||||||
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
|
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
|
||||||
}
|
}
|
||||||
|
EmbeddingModel::TextEmbedding3Large => {
|
||||||
|
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
|
||||||
|
}
|
||||||
|
EmbeddingModel::TextEmbedding3Small => {
|
||||||
|
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_optional_dimensions_supported(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
EmbeddingModel::TextEmbeddingAda002 => false,
|
||||||
|
EmbeddingModel::TextEmbedding3Large => true,
|
||||||
|
EmbeddingModel::TextEmbedding3Small => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -86,11 +121,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|
|||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
pub fn with_default_model(api_key: Option<String>) -> Self {
|
pub fn with_default_model(api_key: Option<String>) -> Self {
|
||||||
Self { api_key, embedding_model: Default::default() }
|
Self { api_key, embedding_model: Default::default(), dimensions: None }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
|
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
|
||||||
Self { api_key, embedding_model }
|
Self { api_key, embedding_model, dimensions: None }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,7 +272,15 @@ impl Embedder {
|
|||||||
for text in texts {
|
for text in texts {
|
||||||
log::trace!("Received prompt: {}", text.as_ref())
|
log::trace!("Received prompt: {}", text.as_ref())
|
||||||
}
|
}
|
||||||
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
|
let request = OpenAiRequest {
|
||||||
|
model: self.options.embedding_model.name(),
|
||||||
|
input: texts,
|
||||||
|
dimension: if self.options.embedding_model.is_optional_dimensions_supported() {
|
||||||
|
self.options.dimensions.as_ref()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
};
|
||||||
let response = client
|
let response = client
|
||||||
.post(OPENAI_EMBEDDINGS_URL)
|
.post(OPENAI_EMBEDDINGS_URL)
|
||||||
.json(&request)
|
.json(&request)
|
||||||
@ -366,7 +409,7 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn dimensions(&self) -> usize {
|
pub fn dimensions(&self) -> usize {
|
||||||
self.options.embedding_model.dimensions()
|
// Only honor the configured override for models that accept one: the request sent to
|
||||||
|
// the API drops the override for other models, so they always return their default size.
|
||||||
|
let model = &self.options.embedding_model;
|
||||||
|
let overridden = self.options.dimensions.filter(|_| model.is_optional_dimensions_supported());
|
||||||
|
overridden.unwrap_or_else(|| model.dimensions())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
@ -431,6 +474,7 @@ impl Retry {
|
|||||||
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
|
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
|
||||||
model: &'a str,
|
model: &'a str,
|
||||||
input: &'a [S],
|
input: &'a [S],
|
||||||
|
// The embeddings API names this parameter `dimensions`; leave it out of the payload when unset.
|
||||||
|
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
|
||||||
|
dimension: Option<&'a usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
@ -208,6 +208,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
|||||||
if let Some(api_key) = api_key.set() {
|
if let Some(api_key) = api_key.set() {
|
||||||
options.api_key = Some(api_key);
|
options.api_key = Some(api_key);
|
||||||
}
|
}
|
||||||
|
if let Some(dimensions) = dimensions.set() {
|
||||||
|
options.dimensions = Some(dimensions);
|
||||||
|
}
|
||||||
this.embedder_options = super::EmbedderOptions::OpenAi(options);
|
this.embedder_options = super::EmbedderOptions::OpenAi(options);
|
||||||
}
|
}
|
||||||
EmbedderSource::HuggingFace => {
|
EmbedderSource::HuggingFace => {
|
||||||
|
Loading…
Reference in New Issue
Block a user