OpenAI sync

This commit is contained in:
parent bc58e8a310
commit c3d02f092d

Cargo.lock (generated):
@@ -3378,6 +3378,7 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tracing",
+ "ureq",
  "uuid",
 ]
 
@@ -91,6 +91,7 @@ liquid = "0.26.4"
 arroy = "0.2.0"
 rand = "0.8.5"
 tracing = "0.1.40"
+ureq = { version = "2.9.6", features = ["json"] }
 
 [dev-dependencies]
 mimalloc = { version = "0.1.39", default-features = false }
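
The two manifest hunks above add ureq 2.9.6 (with its "json" feature) as the blocking HTTP client that replaces reqwest on the OpenAI path. For orientation, here is a minimal sketch of the blocking JSON round-trip ureq 2.x provides; the URL, model name, and payload are placeholders, and serde_json is assumed to be available alongside the "json" feature:

    // Minimal sketch of a blocking JSON POST with ureq 2.x, the pattern the
    // new sync embedder is built on. URL, token, and body are placeholders.
    fn post_example() -> Result<serde_json::Value, Box<dyn std::error::Error>> {
        let agent = ureq::agent();
        let response = agent
            .post("https://api.openai.com/v1/embeddings")
            .set("Authorization", "Bearer <api-key>")
            .send_json(serde_json::json!({
                "model": "text-embedding-3-small",
                "input": ["hello world"],
            }))?;
        // into_json deserializes the response body with serde.
        Ok(response.into_json()?)
    }
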
@@ -53,17 +53,17 @@ pub enum EmbedErrorKind
     #[error("could not run model: {0}")]
     ModelForward(candle_core::Error),
     #[error("could not reach OpenAI: {0}")]
-    OpenAiNetwork(reqwest::Error),
+    OpenAiNetwork(ureq::Transport),
     #[error("unexpected response from OpenAI: {0}")]
-    OpenAiUnexpected(reqwest::Error),
-    #[error("could not authenticate against OpenAI: {0}")]
-    OpenAiAuth(OpenAiError),
-    #[error("sent too many requests to OpenAI: {0}")]
-    OpenAiTooManyRequests(OpenAiError),
+    OpenAiUnexpected(ureq::Error),
+    #[error("could not authenticate against OpenAI: {0:?}")]
+    OpenAiAuth(Option<OpenAiError>),
+    #[error("sent too many requests to OpenAI: {0:?}")]
+    OpenAiTooManyRequests(Option<OpenAiError>),
     #[error("received internal error from OpenAI: {0:?}")]
     OpenAiInternalServerError(Option<OpenAiError>),
-    #[error("sent too many tokens in a request to OpenAI: {0}")]
-    OpenAiTooManyTokens(OpenAiError),
+    #[error("sent too many tokens in a request to OpenAI: {0:?}")]
+    OpenAiTooManyTokens(Option<OpenAiError>),
     #[error("received unhandled HTTP status code {0} from OpenAI")]
     OpenAiUnhandledStatusCode(u16),
     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
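
The switch from `{0}` to `{0:?}` on the changed variants is forced by the new payload type: `Option<OpenAiError>` implements `Debug` but not `Display`, so thiserror must format the optional payload with the Debug formatter. A self-contained sketch of the pattern, with a hypothetical `ApiError` standing in for `OpenAiError`:

    // Why the format strings change to {0:?}: Option<T> has no Display impl,
    // so thiserror falls back to Debug formatting for the optional payload.
    use thiserror::Error;

    #[derive(Debug)]
    pub struct ApiError(pub String); // hypothetical stand-in for OpenAiError

    #[derive(Debug, Error)]
    pub enum ExampleError {
        #[error("could not authenticate: {0:?}")]
        Auth(Option<ApiError>),
    }
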
@@ -102,19 +102,19 @@ impl EmbedError
         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn openai_network(inner: reqwest::Error) -> Self {
+    pub fn openai_network(inner: ureq::Transport) -> Self {
         Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
+    pub fn openai_unexpected(inner: ureq::Error) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
     }
 
-    pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
     }
 
-    pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
     }
 
@@ -122,7 +122,7 @@ impl EmbedError
         Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
     }
 
-    pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
     }
 
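
The constructors above now accept `Option<OpenAiError>` because, on the ureq path, the error body of a failed response is parsed best-effort and may be missing or malformed. An illustrative (not committed) helper showing the call-site shape, mirroring the `check_response` added further down:

    // Illustrative only: a failed response body is parsed best-effort into an
    // Option before being handed to the constructor.
    fn auth_error_from(response: ureq::Response) -> EmbedError {
        let parsed: Option<OpenAiErrorResponse> = response.into_json().ok();
        EmbedError::openai_auth_error(parsed.map(|body| body.error))
    }
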
@@ -220,7 +220,7 @@ impl NewEmbedderError
         Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
+    pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
         Self {
             kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
             fault: FaultSource::Runtime,
@@ -131,7 +131,7 @@ impl Embedder
 
         let embeddings = this
             .embed(vec!["test".into()])
-            .map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
+            .map_err(NewEmbedderError::could_not_determine_dimension)?;
         this.dimensions = embeddings.first().unwrap().dimension();
 
         Ok(this)
@@ -98,7 +98,7 @@ pub enum Embedder
     /// An embedder based on running local models, fetched from the Hugging Face Hub.
     HuggingFace(hf::Embedder),
     /// An embedder based on making embedding queries against the OpenAI API.
-    OpenAi(openai::Embedder),
+    OpenAi(openai::sync::Embedder),
     /// An embedder based on the user providing the embeddings in the documents and queries.
     UserProvided(manual::Embedder),
     Ollama(ollama::Embedder),
@@ -201,7 +201,7 @@ impl Embedder
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
-            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
+            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?),
             EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
             EmbedderOptions::UserProvided(options) => {
                 Self::UserProvided(manual::Embedder::new(options))
@@ -218,10 +218,7 @@ impl Embedder
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => {
-                let client = embedder.new_client()?;
-                embedder.embed(texts, &client).await
-            }
+            Embedder::OpenAi(embedder) => embedder.embed(texts),
             Embedder::Ollama(embedder) => {
                 let client = embedder.new_client()?;
                 embedder.embed(texts, &client).await
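
With the OpenAI arm now synchronous, the dispatch above no longer builds a per-request client or awaits it. A rough sketch of driving the new embedder directly, a sketch only, with the two error types unified behind `Box<dyn Error>` for brevity:

    // Sketch: construction and embedding are plain blocking calls now; no
    // client setup, no executor, no .await on the OpenAI path.
    fn embed_blocking(
        options: openai::EmbedderOptions,
        texts: Vec<String>,
    ) -> Result<Vec<Embeddings<f32>>, Box<dyn std::error::Error>> {
        let embedder = openai::sync::Embedder::new(options)?;
        Ok(embedder.embed(texts)?)
    }
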
@@ -1,18 +1,10 @@
 use std::fmt::Display;
 
-use reqwest::StatusCode;
 use serde::{Deserialize, Serialize};
 
 use super::error::{EmbedError, NewEmbedderError};
 use super::{DistributionShift, Embedding, Embeddings};
 
-#[derive(Debug)]
-pub struct Embedder {
-    headers: reqwest::header::HeaderMap,
-    tokenizer: tiktoken_rs::CoreBPE,
-    options: EmbedderOptions,
-}
-
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub struct EmbedderOptions {
     pub api_key: Option<String>,
@@ -125,298 +117,6 @@ impl EmbedderOptions
     }
 }
 
-impl Embedder {
-    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
-        reqwest::ClientBuilder::new()
-            .default_headers(self.headers.clone())
-            .build()
-            .map_err(EmbedError::openai_initialize_web_client)
-    }
-
-    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
-        let mut headers = reqwest::header::HeaderMap::new();
-        let mut inferred_api_key = Default::default();
-        let api_key = options.api_key.as_ref().unwrap_or_else(|| {
-            inferred_api_key = infer_api_key();
-            &inferred_api_key
-        });
-        headers.insert(
-            reqwest::header::AUTHORIZATION,
-            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
-                .map_err(NewEmbedderError::openai_invalid_api_key_format)?,
-        );
-        headers.insert(
-            reqwest::header::CONTENT_TYPE,
-            reqwest::header::HeaderValue::from_static("application/json"),
-        );
-
-        // looking at the code it is very unclear that this can actually fail.
-        let tokenizer = tiktoken_rs::cl100k_base().unwrap();
-
-        Ok(Self { options, headers, tokenizer })
-    }
-
-    pub async fn embed(
-        &self,
-        texts: Vec<String>,
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
-        let mut tokenized = false;
-
-        for attempt in 0..7 {
-            let result = if tokenized {
-                self.try_embed_tokenized(&texts, client).await
-            } else {
-                self.try_embed(&texts, client).await
-            };
-
-            let retry_duration = match result {
-                Ok(embeddings) => return Ok(embeddings),
-                Err(retry) => {
-                    tracing::warn!("Failed: {}", retry.error);
-                    tokenized |= retry.must_tokenize();
-                    retry.into_duration(attempt)
-                }
-            }?;
-
-            let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
-            tracing::warn!(
-                "Attempt #{}, retrying after {}ms.",
-                attempt,
-                retry_duration.as_millis()
-            );
-            tokio::time::sleep(retry_duration).await;
-        }
-
-        let result = if tokenized {
-            self.try_embed_tokenized(&texts, client).await
-        } else {
-            self.try_embed(&texts, client).await
-        };
-
-        result.map_err(Retry::into_error)
-    }
-
-    async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
-        if !response.status().is_success() {
-            match response.status() {
-                StatusCode::UNAUTHORIZED => {
-                    let error_response: OpenAiErrorResponse = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::openai_unexpected)
-                        .map_err(Retry::retry_later)?;
-
-                    return Err(Retry::give_up(EmbedError::openai_auth_error(
-                        error_response.error,
-                    )));
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    let error_response: OpenAiErrorResponse = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::openai_unexpected)
-                        .map_err(Retry::retry_later)?;
-
-                    return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
-                        error_response.error,
-                    )));
-                }
-                StatusCode::INTERNAL_SERVER_ERROR
-                | StatusCode::BAD_GATEWAY
-                | StatusCode::SERVICE_UNAVAILABLE => {
-                    let error_response: Result<OpenAiErrorResponse, _> = response.json().await;
-                    return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
-                        error_response.ok().map(|error_response| error_response.error),
-                    )));
-                }
-                StatusCode::BAD_REQUEST => {
-                    // Most probably, one text contained too many tokens
-                    let error_response: OpenAiErrorResponse = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::openai_unexpected)
-                        .map_err(Retry::retry_later)?;
-
-                    tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
-
-                    return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
-                        error_response.error,
-                    )));
-                }
-                code => {
-                    return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code(
-                        code.as_u16(),
-                    )));
-                }
-            }
-        }
-        Ok(response)
-    }
-
-    async fn try_embed<S: AsRef<str> + serde::Serialize>(
-        &self,
-        texts: &[S],
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, Retry> {
-        for text in texts {
-            tracing::trace!("Received prompt: {}", text.as_ref())
-        }
-        let request = OpenAiRequest {
-            model: self.options.embedding_model.name(),
-            input: texts,
-            dimensions: self.overriden_dimensions(),
-        };
-        let response = client
-            .post(OPENAI_EMBEDDINGS_URL)
-            .json(&request)
-            .send()
-            .await
-            .map_err(EmbedError::openai_network)
-            .map_err(Retry::retry_later)?;
-
-        let response = Self::check_response(response).await?;
-
-        let response: OpenAiResponse = response
-            .json()
-            .await
-            .map_err(EmbedError::openai_unexpected)
-            .map_err(Retry::retry_later)?;
-
-        tracing::trace!("response: {:?}", response.data);
-
-        Ok(response
-            .data
-            .into_iter()
-            .map(|data| Embeddings::from_single_embedding(data.embedding))
-            .collect())
-    }
-
-    async fn try_embed_tokenized(
-        &self,
-        text: &[String],
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, Retry> {
-        pub const OVERLAP_SIZE: usize = 200;
-        let mut all_embeddings = Vec::with_capacity(text.len());
-        for text in text {
-            let max_token_count = self.options.embedding_model.max_token();
-            let encoded = self.tokenizer.encode_ordinary(text.as_str());
-            let len = encoded.len();
-            if len < max_token_count {
-                all_embeddings.append(&mut self.try_embed(&[text], client).await?);
-                continue;
-            }
-
-            let mut tokens = encoded.as_slice();
-            let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
-            while tokens.len() > max_token_count {
-                let window = &tokens[..max_token_count];
-                embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
-
-                tokens = &tokens[max_token_count - OVERLAP_SIZE..];
-            }
-
-            // end of text
-            embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
-
-            all_embeddings.push(embeddings_for_prompt);
-        }
-        Ok(all_embeddings)
-    }
-
-    async fn embed_tokens(
-        &self,
-        tokens: &[usize],
-        client: &reqwest::Client,
-    ) -> Result<Embedding, Retry> {
-        for attempt in 0..9 {
-            let duration = match self.try_embed_tokens(tokens, client).await {
-                Ok(embedding) => return Ok(embedding),
-                Err(retry) => retry.into_duration(attempt),
-            }
-            .map_err(Retry::retry_later)?;
-
-            tokio::time::sleep(duration).await;
-        }
-
-        self.try_embed_tokens(tokens, client)
-            .await
-            .map_err(|retry| Retry::give_up(retry.into_error()))
-    }
-
-    async fn try_embed_tokens(
-        &self,
-        tokens: &[usize],
-        client: &reqwest::Client,
-    ) -> Result<Embedding, Retry> {
-        let request = OpenAiTokensRequest {
-            model: self.options.embedding_model.name(),
-            input: tokens,
-            dimensions: self.overriden_dimensions(),
-        };
-        let response = client
-            .post(OPENAI_EMBEDDINGS_URL)
-            .json(&request)
-            .send()
-            .await
-            .map_err(EmbedError::openai_network)
-            .map_err(Retry::retry_later)?;
-
-        let response = Self::check_response(response).await?;
-
-        let mut response: OpenAiResponse = response
-            .json()
-            .await
-            .map_err(EmbedError::openai_unexpected)
-            .map_err(Retry::retry_later)?;
-        Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
-    }
-
-    pub fn embed_chunks(
-        &self,
-        text_chunks: Vec<Vec<String>>,
-    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        let rt = tokio::runtime::Builder::new_current_thread()
-            .enable_io()
-            .enable_time()
-            .build()
-            .map_err(EmbedError::openai_runtime_init)?;
-        let client = self.new_client()?;
-        rt.block_on(futures::future::try_join_all(
-            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
-        ))
-    }
-
-    pub fn chunk_count_hint(&self) -> usize {
-        10
-    }
-
-    pub fn prompt_count_in_chunk_hint(&self) -> usize {
-        10
-    }
-
-    pub fn dimensions(&self) -> usize {
-        if self.options.embedding_model.supports_overriding_dimensions() {
-            self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
-        } else {
-            self.options.embedding_model.default_dimensions()
-        }
-    }
-
-    pub fn distribution(&self) -> Option<DistributionShift> {
-        self.options.embedding_model.distribution()
-    }
-
-    fn overriden_dimensions(&self) -> Option<usize> {
-        if self.options.embedding_model.supports_overriding_dimensions() {
-            self.options.dimensions
-        } else {
-            None
-        }
-    }
-}
-
 // retrying in case of failure
 
 pub struct Retry {
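
The hunk below replaces the removed async implementation wholesale. The retry loop keeps its shape, but `tokio::time::sleep(...).await` becomes `std::thread::sleep(...)`, which parks the calling worker thread instead of yielding to an executor. The skeleton of such a capped-backoff loop looks roughly like this; the backoff formula here is illustrative, since the real durations come from `Retry::into_duration`, which is not shown in this diff:

    // Shape of a blocking retry loop with capped exponential backoff. The
    // duration math is illustrative only.
    fn with_retries<T, E>(mut op: impl FnMut() -> Result<T, E>) -> Result<T, E> {
        for attempt in 0u32..7 {
            match op() {
                Ok(value) => return Ok(value),
                Err(_) => {
                    let backoff = std::time::Duration::from_millis(100 * 2u64.pow(attempt));
                    let backoff = backoff.min(std::time::Duration::from_secs(60));
                    std::thread::sleep(backoff); // blocks the calling worker thread
                }
            }
        }
        op() // final attempt; surface the error if it still fails
    }
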
@@ -524,3 +224,257 @@ fn infer_api_key() -> String {
         .or_else(|_| std::env::var("OPENAI_API_KEY"))
         .unwrap_or_default()
 }
+
+pub mod sync {
+    use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
+
+    use super::{
+        EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest,
+        OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL,
+    };
+    use crate::vector::DistributionShift;
+
+    const REQUEST_PARALLELISM: usize = 10;
+
+    #[derive(Debug)]
+    pub struct Embedder {
+        tokenizer: tiktoken_rs::CoreBPE,
+        options: super::EmbedderOptions,
+        bearer: String,
+        threads: rayon::ThreadPool,
+    }
+
+    impl Embedder {
+        pub fn new(options: super::EmbedderOptions) -> Result<Self, NewEmbedderError> {
+            let mut inferred_api_key = Default::default();
+            let api_key = options.api_key.as_ref().unwrap_or_else(|| {
+                inferred_api_key = super::infer_api_key();
+                &inferred_api_key
+            });
+            let bearer = format!("Bearer {api_key}");
+
+            // looking at the code it is very unclear that this can actually fail.
+            let tokenizer = tiktoken_rs::cl100k_base().unwrap();
+
+            // FIXME: unwrap
+            let threads = rayon::ThreadPoolBuilder::new()
+                .num_threads(REQUEST_PARALLELISM)
+                .thread_name(|index| format!("embedder-chunk-{index}"))
+                .build()
+                .unwrap();
+
+            Ok(Self { options, bearer, tokenizer, threads })
+        }
+
+        pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+            let mut tokenized = false;
+
+            let client = ureq::agent();
+
+            for attempt in 0..7 {
+                let result = if tokenized {
+                    self.try_embed_tokenized(&texts, &client)
+                } else {
+                    self.try_embed(&texts, &client)
+                };
+
+                let retry_duration = match result {
+                    Ok(embeddings) => return Ok(embeddings),
+                    Err(retry) => {
+                        tracing::warn!("Failed: {}", retry.error);
+                        tokenized |= retry.must_tokenize();
+                        retry.into_duration(attempt)
+                    }
+                }?;
+
+                let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
+                tracing::warn!(
+                    "Attempt #{}, retrying after {}ms.",
+                    attempt,
+                    retry_duration.as_millis()
+                );
+                std::thread::sleep(retry_duration);
+            }
+
+            let result = if tokenized {
+                self.try_embed_tokenized(&texts, &client)
+            } else {
+                self.try_embed(&texts, &client)
+            };
+
+            result.map_err(Retry::into_error)
+        }
+
+        fn check_response(
+            response: Result<ureq::Response, ureq::Error>,
+        ) -> Result<ureq::Response, Retry> {
+            match response {
+                Ok(response) => Ok(response),
+                Err(ureq::Error::Status(code, response)) => {
+                    let error_response: Option<OpenAiErrorResponse> = response.into_json().ok();
+                    let error = error_response.map(|response| response.error);
+                    Err(match code {
+                        401 => Retry::give_up(EmbedError::openai_auth_error(error)),
+                        429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)),
+                        400 => {
+                            tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
+
+                            Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error))
+                        }
+                        500..=599 => {
+                            Retry::retry_later(EmbedError::openai_internal_server_error(error))
+                        }
+                        x => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)),
+                    })
+                }
+                Err(ureq::Error::Transport(transport)) => {
+                    Err(Retry::retry_later(EmbedError::openai_network(transport)))
+                }
+            }
+        }
+
+        fn try_embed<S: AsRef<str> + serde::Serialize>(
+            &self,
+            texts: &[S],
+            client: &ureq::Agent,
+        ) -> Result<Vec<Embeddings<f32>>, Retry> {
+            for text in texts {
+                tracing::trace!("Received prompt: {}", text.as_ref())
+            }
+            let request = OpenAiRequest {
+                model: self.options.embedding_model.name(),
+                input: texts,
+                dimensions: self.overriden_dimensions(),
+            };
+            let response = client
+                .post(OPENAI_EMBEDDINGS_URL)
+                .set("Authorization", &self.bearer)
+                .send_json(&request);
+
+            let response = Self::check_response(response)?;
+
+            let response: OpenAiResponse = response
+                .into_json()
+                .map_err(EmbedError::openai_unexpected)
+                .map_err(Retry::retry_later)?;
+
+            tracing::trace!("response: {:?}", response.data);
+
+            Ok(response
+                .data
+                .into_iter()
+                .map(|data| Embeddings::from_single_embedding(data.embedding))
+                .collect())
+        }
+
+        fn try_embed_tokenized(
+            &self,
+            text: &[String],
+            client: &ureq::Agent,
+        ) -> Result<Vec<Embeddings<f32>>, Retry> {
+            pub const OVERLAP_SIZE: usize = 200;
+            let mut all_embeddings = Vec::with_capacity(text.len());
+            for text in text {
+                let max_token_count = self.options.embedding_model.max_token();
+                let encoded = self.tokenizer.encode_ordinary(text.as_str());
+                let len = encoded.len();
+                if len < max_token_count {
+                    all_embeddings.append(&mut self.try_embed(&[text], client)?);
+                    continue;
+                }
+
+                let mut tokens = encoded.as_slice();
+                let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
+                while tokens.len() > max_token_count {
+                    let window = &tokens[..max_token_count];
+                    embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap();
+
+                    tokens = &tokens[max_token_count - OVERLAP_SIZE..];
+                }
+
+                // end of text
+                embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap();
+
+                all_embeddings.push(embeddings_for_prompt);
+            }
+            Ok(all_embeddings)
+        }
+
+        fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> {
+            for attempt in 0..9 {
+                let duration = match self.try_embed_tokens(tokens, client) {
+                    Ok(embedding) => return Ok(embedding),
+                    Err(retry) => retry.into_duration(attempt),
+                }
+                .map_err(Retry::retry_later)?;
+
+                std::thread::sleep(duration);
+            }
+
+            self.try_embed_tokens(tokens, client)
+                .map_err(|retry| Retry::give_up(retry.into_error()))
+        }
+
+        fn try_embed_tokens(
+            &self,
+            tokens: &[usize],
+            client: &ureq::Agent,
+        ) -> Result<Embedding, Retry> {
+            let request = OpenAiTokensRequest {
+                model: self.options.embedding_model.name(),
+                input: tokens,
+                dimensions: self.overriden_dimensions(),
+            };
+            let response = client
+                .post(OPENAI_EMBEDDINGS_URL)
+                .set("Authorization", &self.bearer)
+                .send_json(&request);
+
+            let response = Self::check_response(response)?;
+
+            let mut response: OpenAiResponse = response
+                .into_json()
+                .map_err(EmbedError::openai_unexpected)
+                .map_err(Retry::retry_later)?;
+
+            Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
+        }
+
+        pub fn embed_chunks(
+            &self,
+            text_chunks: Vec<Vec<String>>,
+        ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+            self.threads
+                .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
+                .collect()
+        }
+
+        pub fn chunk_count_hint(&self) -> usize {
+            10
+        }
+
+        pub fn prompt_count_in_chunk_hint(&self) -> usize {
+            10
+        }
+
+        pub fn dimensions(&self) -> usize {
+            if self.options.embedding_model.supports_overriding_dimensions() {
+                self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
+            } else {
+                self.options.embedding_model.default_dimensions()
+            }
+        }
+
+        pub fn distribution(&self) -> Option<DistributionShift> {
+            self.options.embedding_model.distribution()
+        }
+
+        fn overriden_dimensions(&self) -> Option<usize> {
+            if self.options.embedding_model.supports_overriding_dimensions() {
+                self.options.dimensions
+            } else {
+                None
+            }
+        }
+    }
+}
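
End to end, the new module is used roughly as below. This is a hedged sketch: construction of `EmbedderOptions` is elided because, apart from `api_key`, its fields do not appear in this diff.

    // Hedged usage sketch of the new synchronous embedder.
    fn embed_all(options: openai::EmbedderOptions, chunks: Vec<Vec<String>>) {
        let embedder = openai::sync::Embedder::new(options).expect("valid options");
        // embed_chunks fans the chunks out over the embedder's private rayon
        // pool (REQUEST_PARALLELISM = 10 worker threads); retries and backoff
        // happen inside each embed call.
        match embedder.embed_chunks(chunks) {
            Ok(embeddings) => println!("embedded {} chunk groups", embeddings.len()),
            Err(error) => eprintln!("embedding failed: {error}"),
        }
    }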