Make sure the overridden dimensions are always used when embedding

This commit is contained in:
Louis Dureuil 2024-02-07 10:36:30 +01:00
parent fb705116a6
commit 7ae4013478
No known key found for this signature in database

View File

@ -65,14 +65,10 @@ impl EmbeddingModel {
}
}
pub fn dimensions(&self) -> usize {
pub fn default_dimensions(&self) -> usize {
match self {
EmbeddingModel::TextEmbeddingAda002 => 1536,
//Default value for the model
EmbeddingModel::TextEmbedding3Large => 1536,
//Default value for the model
EmbeddingModel::TextEmbedding3Small => 3072,
}
}
@ -108,7 +104,7 @@ impl EmbeddingModel {
}
}
pub fn is_optional_dimensions_supported(&self) -> bool {
pub fn supports_overriding_dimensions(&self) -> bool {
match self {
EmbeddingModel::TextEmbeddingAda002 => false,
EmbeddingModel::TextEmbedding3Large => true,
@ -275,7 +271,7 @@ impl Embedder {
let request = OpenAiRequest {
model: self.options.embedding_model.name(),
input: texts,
dimension: if self.options.embedding_model.is_optional_dimensions_supported() {
dimension: if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.as_ref()
} else {
None
@ -323,8 +319,7 @@ impl Embedder {
}
let mut tokens = encoded.as_slice();
let mut embeddings_for_prompt =
Embeddings::new(self.options.embedding_model.dimensions());
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
while tokens.len() > max_token_count {
let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
@ -409,7 +404,11 @@ impl Embedder {
}
pub fn dimensions(&self) -> usize {
self.options.dimensions.unwrap_or_else(|| self.options.embedding_model.dimensions())
if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
} else {
self.options.embedding_model.default_dimensions()
}
}
pub fn distribution(&self) -> Option<DistributionShift> {