mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-02 17:45:46 +01:00
Merge #4371
4371: Fixes embedder issues r=irevoire a=dureuill

# Pull Request

## Related issue
Fixes #4361
Fixes #4370

## What does this PR do?
- Truncate tokens to 512 for Hugging Face embedders
- Move the tokio runtime to OpenAI so that we no longer have a thread with rayon -> tokio -> rayon
- Spawn a new reqwest client after each new runtime to avoid spurious runtime error

## Manual tests
- embedding failing document from `@CaroFG` with hugging face
- embedding movies with hugging face
- embedding and searching movies with openai

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
commit
049bd45849
@ -339,9 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
embedder: Arc<Embedder>,
|
embedder: Arc<Embedder>,
|
||||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||||
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
|
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
||||||
|
|
||||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
|
|
||||||
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
|
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
|
||||||
|
|
||||||
// docid, state with embedding
|
// docid, state with embedding
|
||||||
@ -375,11 +373,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
current_chunk_ids.push(docid);
|
current_chunk_ids.push(docid);
|
||||||
|
|
||||||
if chunks.len() == chunks.capacity() {
|
if chunks.len() == chunks.capacity() {
|
||||||
let chunked_embeds = rt
|
let chunked_embeds = embedder
|
||||||
.block_on(
|
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
|
||||||
embedder
|
|
||||||
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
|
|
||||||
)
|
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
|
|
||||||
@ -396,8 +391,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
|
|
||||||
// send last chunk
|
// send last chunk
|
||||||
if !chunks.is_empty() {
|
if !chunks.is_empty() {
|
||||||
let chunked_embeds = rt
|
let chunked_embeds = embedder
|
||||||
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
|
.embed_chunks(std::mem::take(&mut chunks))
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
for (docid, embeddings) in chunks_ids
|
for (docid, embeddings) in chunks_ids
|
||||||
@ -410,13 +405,15 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !current_chunk.is_empty() {
|
if !current_chunk.is_empty() {
|
||||||
let embeds = rt
|
let embeds = embedder
|
||||||
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
|
.embed_chunks(vec![std::mem::take(&mut current_chunk)])
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
|
|
||||||
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
|
if let Some(embeds) = embeds.first() {
|
||||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
|
||||||
|
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,6 +67,10 @@ pub enum EmbedErrorKind {
|
|||||||
OpenAiUnhandledStatusCode(u16),
|
OpenAiUnhandledStatusCode(u16),
|
||||||
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
||||||
ManualEmbed(String),
|
ManualEmbed(String),
|
||||||
|
#[error("could not initialize asynchronous runtime: {0}")]
|
||||||
|
OpenAiRuntimeInit(std::io::Error),
|
||||||
|
#[error("initializing web client for sending embedding requests failed: {0}")]
|
||||||
|
InitWebClient(reqwest::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedError {
|
impl EmbedError {
|
||||||
@ -117,6 +121,14 @@ impl EmbedError {
|
|||||||
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
||||||
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||||
|
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
@ -183,10 +195,6 @@ impl NewEmbedderError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
|
||||||
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
||||||
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
||||||
}
|
}
|
||||||
@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind {
|
|||||||
#[error("loading model failed: {0}")]
|
#[error("loading model failed: {0}")]
|
||||||
LoadModel(candle_core::Error),
|
LoadModel(candle_core::Error),
|
||||||
// openai
|
// openai
|
||||||
#[error("initializing web client for sending embedding requests failed: {0}")]
|
|
||||||
InitWebClient(reqwest::Error),
|
|
||||||
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
|
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
|
||||||
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
|
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,8 @@ impl Embedder {
|
|||||||
let token_ids = tokens
|
let token_ids = tokens
|
||||||
.iter()
|
.iter()
|
||||||
.map(|tokens| {
|
.map(|tokens| {
|
||||||
let tokens = tokens.get_ids().to_vec();
|
let mut tokens = tokens.get_ids().to_vec();
|
||||||
|
tokens.truncate(512);
|
||||||
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, EmbedError>>()?;
|
.collect::<Result<Vec<_>, EmbedError>>()?;
|
||||||
|
@ -163,18 +163,24 @@ impl Embedder {
|
|||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
|
Embedder::OpenAi(embedder) => {
|
||||||
|
let client = embedder.new_client()?;
|
||||||
|
embedder.embed(texts, &client).await
|
||||||
|
}
|
||||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed_chunks(
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// - if called from an asynchronous context
|
||||||
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
|
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
client: reqwest::Client,
|
headers: reqwest::header::HeaderMap,
|
||||||
tokenizer: tiktoken_rs::CoreBPE,
|
tokenizer: tiktoken_rs::CoreBPE,
|
||||||
options: EmbedderOptions,
|
options: EmbedderOptions,
|
||||||
}
|
}
|
||||||
@ -95,6 +95,13 @@ impl EmbedderOptions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
|
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
|
||||||
|
reqwest::ClientBuilder::new()
|
||||||
|
.default_headers(self.headers.clone())
|
||||||
|
.build()
|
||||||
|
.map_err(EmbedError::openai_initialize_web_client)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
let mut inferred_api_key = Default::default();
|
let mut inferred_api_key = Default::default();
|
||||||
@ -111,25 +118,25 @@ impl Embedder {
|
|||||||
reqwest::header::CONTENT_TYPE,
|
reqwest::header::CONTENT_TYPE,
|
||||||
reqwest::header::HeaderValue::from_static("application/json"),
|
reqwest::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
let client = reqwest::ClientBuilder::new()
|
|
||||||
.default_headers(headers)
|
|
||||||
.build()
|
|
||||||
.map_err(NewEmbedderError::openai_initialize_web_client)?;
|
|
||||||
|
|
||||||
// looking at the code it is very unclear that this can actually fail.
|
// looking at the code it is very unclear that this can actually fail.
|
||||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||||
|
|
||||||
Ok(Self { options, client, tokenizer })
|
Ok(Self { options, headers, tokenizer })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub async fn embed(
|
||||||
|
&self,
|
||||||
|
texts: Vec<String>,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
let mut tokenized = false;
|
let mut tokenized = false;
|
||||||
|
|
||||||
for attempt in 0..7 {
|
for attempt in 0..7 {
|
||||||
let result = if tokenized {
|
let result = if tokenized {
|
||||||
self.try_embed_tokenized(&texts).await
|
self.try_embed_tokenized(&texts, client).await
|
||||||
} else {
|
} else {
|
||||||
self.try_embed(&texts).await
|
self.try_embed(&texts, client).await
|
||||||
};
|
};
|
||||||
|
|
||||||
let retry_duration = match result {
|
let retry_duration = match result {
|
||||||
@ -145,9 +152,9 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let result = if tokenized {
|
let result = if tokenized {
|
||||||
self.try_embed_tokenized(&texts).await
|
self.try_embed_tokenized(&texts, client).await
|
||||||
} else {
|
} else {
|
||||||
self.try_embed(&texts).await
|
self.try_embed(&texts, client).await
|
||||||
};
|
};
|
||||||
|
|
||||||
result.map_err(Retry::into_error)
|
result.map_err(Retry::into_error)
|
||||||
@ -225,13 +232,13 @@ impl Embedder {
|
|||||||
async fn try_embed<S: AsRef<str> + serde::Serialize>(
|
async fn try_embed<S: AsRef<str> + serde::Serialize>(
|
||||||
&self,
|
&self,
|
||||||
texts: &[S],
|
texts: &[S],
|
||||||
|
client: &reqwest::Client,
|
||||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||||
for text in texts {
|
for text in texts {
|
||||||
log::trace!("Received prompt: {}", text.as_ref())
|
log::trace!("Received prompt: {}", text.as_ref())
|
||||||
}
|
}
|
||||||
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
|
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
|
||||||
let response = self
|
let response = client
|
||||||
.client
|
|
||||||
.post(OPENAI_EMBEDDINGS_URL)
|
.post(OPENAI_EMBEDDINGS_URL)
|
||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
@ -256,7 +263,11 @@ impl Embedder {
|
|||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
|
async fn try_embed_tokenized(
|
||||||
|
&self,
|
||||||
|
text: &[String],
|
||||||
|
client: &reqwest::Client,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||||
pub const OVERLAP_SIZE: usize = 200;
|
pub const OVERLAP_SIZE: usize = 200;
|
||||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||||
for text in text {
|
for text in text {
|
||||||
@ -264,7 +275,7 @@ impl Embedder {
|
|||||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||||
let len = encoded.len();
|
let len = encoded.len();
|
||||||
if len < max_token_count {
|
if len < max_token_count {
|
||||||
all_embeddings.append(&mut self.try_embed(&[text]).await?);
|
all_embeddings.append(&mut self.try_embed(&[text], client).await?);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,22 +284,26 @@ impl Embedder {
|
|||||||
Embeddings::new(self.options.embedding_model.dimensions());
|
Embeddings::new(self.options.embedding_model.dimensions());
|
||||||
while tokens.len() > max_token_count {
|
while tokens.len() > max_token_count {
|
||||||
let window = &tokens[..max_token_count];
|
let window = &tokens[..max_token_count];
|
||||||
embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
|
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
|
||||||
|
|
||||||
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
||||||
}
|
}
|
||||||
|
|
||||||
// end of text
|
// end of text
|
||||||
embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
|
embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
|
||||||
|
|
||||||
all_embeddings.push(embeddings_for_prompt);
|
all_embeddings.push(embeddings_for_prompt);
|
||||||
}
|
}
|
||||||
Ok(all_embeddings)
|
Ok(all_embeddings)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
|
async fn embed_tokens(
|
||||||
|
&self,
|
||||||
|
tokens: &[usize],
|
||||||
|
client: &reqwest::Client,
|
||||||
|
) -> Result<Embedding, Retry> {
|
||||||
for attempt in 0..9 {
|
for attempt in 0..9 {
|
||||||
let duration = match self.try_embed_tokens(tokens).await {
|
let duration = match self.try_embed_tokens(tokens, client).await {
|
||||||
Ok(embedding) => return Ok(embedding),
|
Ok(embedding) => return Ok(embedding),
|
||||||
Err(retry) => retry.into_duration(attempt),
|
Err(retry) => retry.into_duration(attempt),
|
||||||
}
|
}
|
||||||
@ -297,14 +312,19 @@ impl Embedder {
|
|||||||
tokio::time::sleep(duration).await;
|
tokio::time::sleep(duration).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
|
self.try_embed_tokens(tokens, client)
|
||||||
|
.await
|
||||||
|
.map_err(|retry| Retry::give_up(retry.into_error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
|
async fn try_embed_tokens(
|
||||||
|
&self,
|
||||||
|
tokens: &[usize],
|
||||||
|
client: &reqwest::Client,
|
||||||
|
) -> Result<Embedding, Retry> {
|
||||||
let request =
|
let request =
|
||||||
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
|
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
|
||||||
let response = self
|
let response = client
|
||||||
.client
|
|
||||||
.post(OPENAI_EMBEDDINGS_URL)
|
.post(OPENAI_EMBEDDINGS_URL)
|
||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
@ -322,12 +342,19 @@ impl Embedder {
|
|||||||
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
|
let rt = tokio::runtime::Builder::new_current_thread()
|
||||||
.await
|
.enable_io()
|
||||||
|
.enable_time()
|
||||||
|
.build()
|
||||||
|
.map_err(EmbedError::openai_runtime_init)?;
|
||||||
|
let client = self.new_client()?;
|
||||||
|
rt.block_on(futures::future::try_join_all(
|
||||||
|
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
Loading…
Reference in New Issue
Block a user