Don't use a runtime in extract_embedder, use it only for OpenAI

Louis Dureuil 2024-01-29 11:23:18 +01:00
parent 1555870088
commit fbf5f2a392
4 changed files with 85 additions and 49 deletions
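
In short: `Embedder::new` now stores the request headers instead of a pre-built `reqwest::Client`, every embedding method takes the client as a `&reqwest::Client` parameter, and `embed_chunks` becomes a synchronous function that builds its own current-thread Tokio runtime, so only the OpenAI embedder ever owns a runtime. A minimal sketch of that pattern, outside this codebase (the function name and the word-count "work" are illustrative stand-ins, not from the commit):

// Sketch only: a synchronous entry point that owns a current-thread Tokio
// runtime and drives async work with block_on, so callers need no runtime.
fn process_chunks(chunks: Vec<Vec<String>>) -> Result<Vec<usize>, std::io::Error> {
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_io()
        .enable_time()
        .build()?; // runtime construction can fail, e.g. if IO setup fails
    // Drive all per-chunk futures to completion on this one thread.
    rt.block_on(futures::future::try_join_all(
        chunks.into_iter().map(|prompts| async move {
            Ok::<usize, std::io::Error>(prompts.len())
        }),
    ))
}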


@@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};
 #[derive(Debug)]
 pub struct Embedder {
-    client: reqwest::Client,
+    headers: reqwest::header::HeaderMap,
     tokenizer: tiktoken_rs::CoreBPE,
     options: EmbedderOptions,
 }

@@ -95,6 +95,13 @@ impl EmbedderOptions {
 }

 impl Embedder {
+    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
+        reqwest::ClientBuilder::new()
+            .default_headers(self.headers.clone())
+            .build()
+            .map_err(EmbedError::openai_initialize_web_client)
+    }
+
     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
         let mut headers = reqwest::header::HeaderMap::new();
         let mut inferred_api_key = Default::default();
@@ -111,25 +118,25 @@ impl Embedder {
             reqwest::header::CONTENT_TYPE,
             reqwest::header::HeaderValue::from_static("application/json"),
         );
-        let client = reqwest::ClientBuilder::new()
-            .default_headers(headers)
-            .build()
-            .map_err(NewEmbedderError::openai_initialize_web_client)?;

         // looking at the code it is very unclear that this can actually fail.
         let tokenizer = tiktoken_rs::cl100k_base().unwrap();

-        Ok(Self { options, client, tokenizer })
+        Ok(Self { options, headers, tokenizer })
     }

-    pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    pub async fn embed(
+        &self,
+        texts: Vec<String>,
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
         let mut tokenized = false;

         for attempt in 0..7 {
             let result = if tokenized {
-                self.try_embed_tokenized(&texts).await
+                self.try_embed_tokenized(&texts, client).await
             } else {
-                self.try_embed(&texts).await
+                self.try_embed(&texts, client).await
             };

             let retry_duration = match result {
@@ -145,9 +152,9 @@ impl Embedder {
         }

         let result = if tokenized {
-            self.try_embed_tokenized(&texts).await
+            self.try_embed_tokenized(&texts, client).await
         } else {
-            self.try_embed(&texts).await
+            self.try_embed(&texts, client).await
         };

         result.map_err(Retry::into_error)
@@ -225,13 +232,13 @@ impl Embedder {
     async fn try_embed<S: AsRef<str> + serde::Serialize>(
         &self,
         texts: &[S],
+        client: &reqwest::Client,
     ) -> Result<Vec<Embeddings<f32>>, Retry> {
         for text in texts {
             log::trace!("Received prompt: {}", text.as_ref())
         }
         let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -256,7 +263,11 @@ impl Embedder {
             .collect())
     }

-    async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
+    async fn try_embed_tokenized(
+        &self,
+        text: &[String],
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, Retry> {
         pub const OVERLAP_SIZE: usize = 200;
         let mut all_embeddings = Vec::with_capacity(text.len());
         for text in text {
@@ -264,7 +275,7 @@ impl Embedder {
             let encoded = self.tokenizer.encode_ordinary(text.as_str());
             let len = encoded.len();
             if len < max_token_count {
-                all_embeddings.append(&mut self.try_embed(&[text]).await?);
+                all_embeddings.append(&mut self.try_embed(&[text], client).await?);
                 continue;
             }

@@ -273,22 +284,26 @@ impl Embedder {
                 Embeddings::new(self.options.embedding_model.dimensions());

             while tokens.len() > max_token_count {
                 let window = &tokens[..max_token_count];
-                embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
+                embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();

                 tokens = &tokens[max_token_count - OVERLAP_SIZE..];
             }

             // end of text
-            embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
+            embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
             all_embeddings.push(embeddings_for_prompt);
         }
         Ok(all_embeddings)
     }

-    async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         for attempt in 0..9 {
-            let duration = match self.try_embed_tokens(tokens).await {
+            let duration = match self.try_embed_tokens(tokens, client).await {
                 Ok(embedding) => return Ok(embedding),
                 Err(retry) => retry.into_duration(attempt),
             }
@@ -297,14 +312,19 @@ impl Embedder {
             tokio::time::sleep(duration).await;
         }

-        self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
+        self.try_embed_tokens(tokens, client)
+            .await
+            .map_err(|retry| Retry::give_up(retry.into_error()))
     }

-    async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn try_embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         let request =
             OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -322,12 +342,19 @@ impl Embedder {
         Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
     }

-    pub async fn embed_chunks(
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
-            .await
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_io()
+            .enable_time()
+            .build()
+            .map_err(EmbedError::openai_runtime_init)?;
+        let client = self.new_client()?;
+        rt.block_on(futures::future::try_join_all(
+            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
+        ))
     }

     pub fn chunk_count_hint(&self) -> usize {
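
After this change, callers such as the extraction pipeline can invoke `embed_chunks` as a plain blocking call with no surrounding runtime. A hypothetical call site, not part of this diff (`options` stands in for whatever `EmbedderOptions` the caller has built):

// Hypothetical call site: no async context or ambient Tokio runtime needed.
let embedder = Embedder::new(options)?;
let all_embeddings = embedder.embed_chunks(vec![
    vec!["first prompt".to_string(), "second prompt".to_string()],
])?;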