diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs
index 20013d8e8..8712c7894 100644
--- a/milli/src/vector/openai.rs
+++ b/milli/src/vector/openai.rs
@@ -65,14 +65,10 @@ impl EmbeddingModel {
         }
     }
 
-    pub fn dimensions(&self) -> usize {
+    pub fn default_dimensions(&self) -> usize {
         match self {
             EmbeddingModel::TextEmbeddingAda002 => 1536,
-
-            //Default value for the model
             EmbeddingModel::TextEmbedding3Large => 1536,
-
-            //Default value for the model
             EmbeddingModel::TextEmbedding3Small => 3072,
         }
     }
@@ -108,7 +104,7 @@ impl EmbeddingModel {
         }
     }
 
-    pub fn is_optional_dimensions_supported(&self) -> bool {
+    pub fn supports_overriding_dimensions(&self) -> bool {
         match self {
             EmbeddingModel::TextEmbeddingAda002 => false,
             EmbeddingModel::TextEmbedding3Large => true,
@@ -275,7 +271,7 @@ impl Embedder {
         let request = OpenAiRequest {
             model: self.options.embedding_model.name(),
             input: texts,
-            dimension: if self.options.embedding_model.is_optional_dimensions_supported() {
+            dimension: if self.options.embedding_model.supports_overriding_dimensions() {
                 self.options.dimensions.as_ref()
             } else {
                 None
@@ -323,8 +319,7 @@ impl Embedder {
             }
 
             let mut tokens = encoded.as_slice();
-            let mut embeddings_for_prompt =
-                Embeddings::new(self.options.embedding_model.dimensions());
+            let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
             while tokens.len() > max_token_count {
                 let window = &tokens[..max_token_count];
                 embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
@@ -409,7 +404,11 @@ impl Embedder {
     }
 
     pub fn dimensions(&self) -> usize {
-        self.options.dimensions.unwrap_or_else(|| self.options.embedding_model.dimensions())
+        if self.options.embedding_model.supports_overriding_dimensions() {
+            self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
+        } else {
+            self.options.embedding_model.default_dimensions()
+        }
     }
 
     pub fn distribution(&self) -> Option<DistributionShift> {