pass dimensions only when defined

This commit is contained in:
Louis Dureuil 2024-02-07 11:03:00 +01:00
parent 517f5332d6
commit 74c180267e
No known key found for this signature in database

View File

@ -271,11 +271,7 @@ impl Embedder {
let request = OpenAiRequest {
model: self.options.embedding_model.name(),
input: texts,
dimension: if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions.as_ref()
} else {
None
},
dimensions: self.overriden_dimensions(),
};
let response = client
.post(OPENAI_EMBEDDINGS_URL)
@ -360,8 +356,11 @@ impl Embedder {
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
let request =
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
let request = OpenAiTokensRequest {
model: self.options.embedding_model.name(),
input: tokens,
dimensions: self.overriden_dimensions(),
};
let response = client
.post(OPENAI_EMBEDDINGS_URL)
.json(&request)
@ -414,6 +413,14 @@ impl Embedder {
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution()
}
fn overriden_dimensions(&self) -> Option<usize> {
if self.options.embedding_model.supports_overriding_dimensions() {
self.options.dimensions
} else {
None
}
}
}
// retrying in case of failure
@ -473,13 +480,16 @@ impl Retry {
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
model: &'a str,
input: &'a [S],
dimension: Option<&'a usize>,
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[derive(Debug, Serialize)]
struct OpenAiTokensRequest<'a> {
model: &'a str,
input: &'a [usize],
#[serde(skip_serializing_if = "Option::is_none")]
dimensions: Option<usize>,
}
#[derive(Debug, Deserialize)]