mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-30 00:34:26 +01:00
Merge #5039
5039: Add 3s timeout to embedding requests made during search r=irevoire a=dureuill

# Pull Request

## Related issue
Fixes #5032

## What does this PR do?
- Add a 3-second timeout to embedding requests against a remote embedder made in the context of search. The timeout triggers when there are failing requests due to rate-limiting.
- Add a test of that timeout.

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
commit
2c1c33166d
@ -5201,9 +5201,10 @@ mod tests {
|
|||||||
|
|
||||||
let configs = index_scheduler.embedders(configs).unwrap();
|
let configs = index_scheduler.embedders(configs).unwrap();
|
||||||
let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap();
|
let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap();
|
||||||
let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap();
|
let beagle_embed =
|
||||||
let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap();
|
hf_embedder.embed_one(S("Intel the beagle best doggo"), None).unwrap();
|
||||||
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap();
|
let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo"), None).unwrap();
|
||||||
|
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo"), None).unwrap();
|
||||||
(fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed)
|
(fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -796,8 +796,10 @@ fn prepare_search<'t>(
|
|||||||
let span = tracing::trace_span!(target: "search::vector", "embed_one");
|
let span = tracing::trace_span!(target: "search::vector", "embed_one");
|
||||||
let _entered = span.enter();
|
let _entered = span.enter();
|
||||||
|
|
||||||
|
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
|
||||||
|
|
||||||
embedder
|
embedder
|
||||||
.embed_one(query.q.clone().unwrap())
|
.embed_one(query.q.clone().unwrap(), Some(deadline))
|
||||||
.map_err(milli::vector::Error::from)
|
.map_err(milli::vector::Error::from)
|
||||||
.map_err(milli::Error::from)?
|
.map_err(milli::Error::from)?
|
||||||
}
|
}
|
||||||
|
@ -137,13 +137,14 @@ fn long_text() -> &'static str {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn create_mock_tokenized() -> (MockServer, Value) {
|
async fn create_mock_tokenized() -> (MockServer, Value) {
|
||||||
create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false).await
|
create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false, false).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_mock_with_template(
|
async fn create_mock_with_template(
|
||||||
document_template: &str,
|
document_template: &str,
|
||||||
model_dimensions: ModelDimensions,
|
model_dimensions: ModelDimensions,
|
||||||
fallible: bool,
|
fallible: bool,
|
||||||
|
slow: bool,
|
||||||
) -> (MockServer, Value) {
|
) -> (MockServer, Value) {
|
||||||
let mock_server = MockServer::start().await;
|
let mock_server = MockServer::start().await;
|
||||||
const API_KEY: &str = "my-api-key";
|
const API_KEY: &str = "my-api-key";
|
||||||
@ -154,7 +155,11 @@ async fn create_mock_with_template(
|
|||||||
Mock::given(method("POST"))
|
Mock::given(method("POST"))
|
||||||
.and(path("/"))
|
.and(path("/"))
|
||||||
.respond_with(move |req: &Request| {
|
.respond_with(move |req: &Request| {
|
||||||
// 0. maybe return 500
|
// 0. if `slow` is set, wait for 1 second before responding
|
||||||
|
if slow {
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
}
|
||||||
|
// 1. maybe return 500
|
||||||
if fallible {
|
if fallible {
|
||||||
let attempt = attempt.fetch_add(1, Ordering::Relaxed);
|
let attempt = attempt.fetch_add(1, Ordering::Relaxed);
|
||||||
let failed = matches!(attempt % 4, 0 | 1 | 3);
|
let failed = matches!(attempt % 4, 0 | 1 | 3);
|
||||||
@ -167,7 +172,7 @@ async fn create_mock_with_template(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 1. check API key
|
// 2. check API key
|
||||||
match req.headers.get("Authorization") {
|
match req.headers.get("Authorization") {
|
||||||
Some(api_key) if api_key == API_KEY_BEARER => {
|
Some(api_key) if api_key == API_KEY_BEARER => {
|
||||||
{}
|
{}
|
||||||
@ -202,7 +207,7 @@ async fn create_mock_with_template(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 2. parse text inputs
|
// 3. parse text inputs
|
||||||
let query: serde_json::Value = match req.body_json() {
|
let query: serde_json::Value = match req.body_json() {
|
||||||
Ok(query) => query,
|
Ok(query) => query,
|
||||||
Err(_error) => return ResponseTemplate::new(400).set_body_json(
|
Err(_error) => return ResponseTemplate::new(400).set_body_json(
|
||||||
@ -223,7 +228,7 @@ async fn create_mock_with_template(
|
|||||||
panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}")
|
panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. for each text, find embedding in responses
|
// 4. for each text, find embedding in responses
|
||||||
let serde_json::Value::Array(inputs) = &query["input"] else {
|
let serde_json::Value::Array(inputs) = &query["input"] else {
|
||||||
panic!("Unexpected `input` value")
|
panic!("Unexpected `input` value")
|
||||||
};
|
};
|
||||||
@ -283,7 +288,7 @@ async fn create_mock_with_template(
|
|||||||
"embedding": embedding,
|
"embedding": embedding,
|
||||||
})).collect();
|
})).collect();
|
||||||
|
|
||||||
// 4. produce output from embeddings
|
// 5. produce output from embeddings
|
||||||
ResponseTemplate::new(200).set_body_json(json!({
|
ResponseTemplate::new(200).set_body_json(json!({
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data,
|
"data": data,
|
||||||
@ -317,23 +322,27 @@ const DOGGO_TEMPLATE: &str = r#"{%- if doc.gender == "F" -%}Une chienne nommée
|
|||||||
{%- endif %}, de race {{doc.breed}}."#;
|
{%- endif %}, de race {{doc.breed}}."#;
|
||||||
|
|
||||||
async fn create_mock() -> (MockServer, Value) {
|
async fn create_mock() -> (MockServer, Value) {
|
||||||
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false).await
|
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false, false).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_mock_dimensions() -> (MockServer, Value) {
|
async fn create_mock_dimensions() -> (MockServer, Value) {
|
||||||
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false).await
|
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false, false).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_mock_small_embedding_model() -> (MockServer, Value) {
|
async fn create_mock_small_embedding_model() -> (MockServer, Value) {
|
||||||
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false).await
|
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false, false).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_mock_legacy_embedding_model() -> (MockServer, Value) {
|
async fn create_mock_legacy_embedding_model() -> (MockServer, Value) {
|
||||||
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false).await
|
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false, false).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_fallible_mock() -> (MockServer, Value) {
|
async fn create_fallible_mock() -> (MockServer, Value) {
|
||||||
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true).await
|
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, false).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_slow_mock() -> (MockServer, Value) {
|
||||||
|
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, true).await
|
||||||
}
|
}
|
||||||
|
|
||||||
// basic test "it works"
|
// basic test "it works"
|
||||||
@ -1873,4 +1882,114 @@ async fn it_still_works() {
|
|||||||
]
|
]
|
||||||
"###);
|
"###);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test of the search embedding timeout, with a slow server (1s delay per request) that also responds 500 on 3 out of 4 calls
|
||||||
|
#[actix_rt::test]
|
||||||
|
async fn timeout() {
|
||||||
|
let (_mock, setting) = create_slow_mock().await;
|
||||||
|
let server = get_server_vector().await;
|
||||||
|
let index = server.index("doggo");
|
||||||
|
|
||||||
|
let (response, code) = index
|
||||||
|
.update_settings(json!({
|
||||||
|
"embedders": {
|
||||||
|
"default": setting,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = server.wait_task(response.uid()).await;
|
||||||
|
snapshot!(task["status"], @r###""succeeded""###);
|
||||||
|
let documents = json!([
|
||||||
|
{"id": 0, "name": "kefir", "gender": "M", "birthyear": 2023, "breed": "Patou"},
|
||||||
|
]);
|
||||||
|
let (value, code) = index.add_documents(documents, None).await;
|
||||||
|
snapshot!(code, @"202 Accepted");
|
||||||
|
let task = index.wait_task(value.uid()).await;
|
||||||
|
snapshot!(task, @r###"
|
||||||
|
{
|
||||||
|
"uid": "[uid]",
|
||||||
|
"indexUid": "doggo",
|
||||||
|
"status": "succeeded",
|
||||||
|
"type": "documentAdditionOrUpdate",
|
||||||
|
"canceledBy": null,
|
||||||
|
"details": {
|
||||||
|
"receivedDocuments": 1,
|
||||||
|
"indexedDocuments": 1
|
||||||
|
},
|
||||||
|
"error": null,
|
||||||
|
"duration": "[duration]",
|
||||||
|
"enqueuedAt": "[date]",
|
||||||
|
"startedAt": "[date]",
|
||||||
|
"finishedAt": "[date]"
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let (documents, _code) = index
|
||||||
|
.get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() })
|
||||||
|
.await;
|
||||||
|
snapshot!(json_string!(documents, {".results.*._vectors.default.embeddings" => "[vector]"}), @r###"
|
||||||
|
{
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"name": "kefir",
|
||||||
|
"gender": "M",
|
||||||
|
"birthyear": 2023,
|
||||||
|
"breed": "Patou",
|
||||||
|
"_vectors": {
|
||||||
|
"default": {
|
||||||
|
"embeddings": "[vector]",
|
||||||
|
"regenerate": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"offset": 0,
|
||||||
|
"limit": 20,
|
||||||
|
"total": 1
|
||||||
|
}
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let (response, code) = index
|
||||||
|
.search_post(json!({
|
||||||
|
"q": "grand chien de berger des montagnes",
|
||||||
|
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
snapshot!(code, @"200 OK");
|
||||||
|
snapshot!(json_string!(response["semanticHitCount"]), @"0");
|
||||||
|
snapshot!(json_string!(response["hits"]), @"[]");
|
||||||
|
|
||||||
|
let (response, code) = index
|
||||||
|
.search_post(json!({
|
||||||
|
"q": "grand chien de berger des montagnes",
|
||||||
|
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
snapshot!(code, @"200 OK");
|
||||||
|
snapshot!(json_string!(response["semanticHitCount"]), @"1");
|
||||||
|
snapshot!(json_string!(response["hits"]), @r###"
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"name": "kefir",
|
||||||
|
"gender": "M",
|
||||||
|
"birthyear": 2023,
|
||||||
|
"breed": "Patou"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"###);
|
||||||
|
|
||||||
|
let (response, code) = index
|
||||||
|
.search_post(json!({
|
||||||
|
"q": "grand chien de berger des montagnes",
|
||||||
|
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
|
||||||
|
}))
|
||||||
|
.await;
|
||||||
|
snapshot!(code, @"200 OK");
|
||||||
|
snapshot!(json_string!(response["semanticHitCount"]), @"0");
|
||||||
|
snapshot!(json_string!(response["hits"]), @"[]");
|
||||||
|
}
|
||||||
|
|
||||||
// test with a server that wrongly responds 400
|
// test with a server that wrongly responds 400
|
||||||
|
@ -201,7 +201,9 @@ impl<'a> Search<'a> {
|
|||||||
let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
|
let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
|
||||||
let _entered = span.enter();
|
let _entered = span.enter();
|
||||||
|
|
||||||
match embedder.embed_one(query) {
|
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
|
||||||
|
|
||||||
|
match embedder.embed_one(query, Some(deadline)) {
|
||||||
Ok(embedding) => embedding,
|
Ok(embedding) => embedding,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
tracing::error!(error=%error, "Embedding failed");
|
tracing::error!(error=%error, "Embedding failed");
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use arroy::distances::{BinaryQuantizedCosine, Cosine};
|
use arroy::distances::{BinaryQuantizedCosine, Cosine};
|
||||||
use arroy::ItemId;
|
use arroy::ItemId;
|
||||||
@ -594,18 +595,23 @@ impl Embedder {
|
|||||||
pub fn embed(
|
pub fn embed(
|
||||||
&self,
|
&self,
|
||||||
texts: Vec<String>,
|
texts: Vec<String>,
|
||||||
|
deadline: Option<Instant>,
|
||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
Embedder::OpenAi(embedder) => embedder.embed(texts, deadline),
|
||||||
Embedder::Ollama(embedder) => embedder.embed(texts),
|
Embedder::Ollama(embedder) => embedder.embed(texts, deadline),
|
||||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||||
Embedder::Rest(embedder) => embedder.embed(texts),
|
Embedder::Rest(embedder) => embedder.embed(texts, deadline),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
|
pub fn embed_one(
|
||||||
let mut embeddings = self.embed(vec![text])?;
|
&self,
|
||||||
|
text: String,
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> std::result::Result<Embedding, EmbedError> {
|
||||||
|
let mut embeddings = self.embed(vec![text], deadline)?;
|
||||||
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
|
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
|
||||||
Ok(if embeddings.iter().nth(1).is_some() {
|
Ok(if embeddings.iter().nth(1).is_some() {
|
||||||
tracing::warn!("Ignoring embeddings past the first one in long search query");
|
tracing::warn!("Ignoring embeddings past the first one in long search query");
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
|
|
||||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||||
@ -75,8 +77,12 @@ impl Embedder {
|
|||||||
Ok(Self { rest_embedder })
|
Ok(Self { rest_embedder })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(
|
||||||
match self.rest_embedder.embed(texts) {
|
&self,
|
||||||
|
texts: Vec<String>,
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
|
match self.rest_embedder.embed(texts, deadline) {
|
||||||
Ok(embeddings) => Ok(embeddings),
|
Ok(embeddings) => Ok(embeddings),
|
||||||
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
||||||
Err(EmbedError::ollama_model_not_found(error))
|
Err(EmbedError::ollama_model_not_found(error))
|
||||||
@ -92,7 +98,7 @@ impl Embedder {
|
|||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads
|
threads
|
||||||
.install(move || {
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
|
||||||
})
|
})
|
||||||
.map_err(|error| EmbedError {
|
.map_err(|error| EmbedError {
|
||||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||||
|
|
||||||
@ -206,32 +208,40 @@ impl Embedder {
|
|||||||
Ok(Self { options, rest_embedder, tokenizer })
|
Ok(Self { options, rest_embedder, tokenizer })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(
|
||||||
match self.rest_embedder.embed_ref(&texts) {
|
&self,
|
||||||
|
texts: Vec<String>,
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
|
match self.rest_embedder.embed_ref(&texts, deadline) {
|
||||||
Ok(embeddings) => Ok(embeddings),
|
Ok(embeddings) => Ok(embeddings),
|
||||||
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
|
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
|
||||||
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
||||||
self.try_embed_tokenized(&texts)
|
self.try_embed_tokenized(&texts, deadline)
|
||||||
}
|
}
|
||||||
Err(error) => Err(error),
|
Err(error) => Err(error),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
fn try_embed_tokenized(
|
||||||
|
&self,
|
||||||
|
text: &[String],
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||||
for text in text {
|
for text in text {
|
||||||
let max_token_count = self.options.embedding_model.max_token();
|
let max_token_count = self.options.embedding_model.max_token();
|
||||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||||
let len = encoded.len();
|
let len = encoded.len();
|
||||||
if len < max_token_count {
|
if len < max_token_count {
|
||||||
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
|
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let tokens = &encoded.as_slice()[0..max_token_count];
|
let tokens = &encoded.as_slice()[0..max_token_count];
|
||||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
||||||
|
|
||||||
let embedding = self.rest_embedder.embed_tokens(tokens)?;
|
let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?;
|
||||||
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
|
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
|
||||||
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
|
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
|
||||||
})?;
|
})?;
|
||||||
@ -248,7 +258,7 @@ impl Embedder {
|
|||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads
|
threads
|
||||||
.install(move || {
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
|
||||||
})
|
})
|
||||||
.map_err(|error| EmbedError {
|
.map_err(|error| EmbedError {
|
||||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
@ -154,19 +155,31 @@ impl Embedder {
|
|||||||
Ok(Self { data, dimensions, distribution: options.distribution })
|
Ok(Self { data, dimensions, distribution: options.distribution })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(
|
||||||
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
|
&self,
|
||||||
|
texts: Vec<String>,
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
|
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
pub fn embed_ref<S>(
|
||||||
|
&self,
|
||||||
|
texts: &[S],
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
where
|
where
|
||||||
S: AsRef<str> + Serialize,
|
S: AsRef<str> + Serialize,
|
||||||
{
|
{
|
||||||
embed(&self.data, texts, texts.len(), Some(self.dimensions))
|
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
pub fn embed_tokens(
|
||||||
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
|
&self,
|
||||||
|
tokens: &[usize],
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
) -> Result<Embeddings<f32>, EmbedError> {
|
||||||
|
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
|
||||||
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(embeddings.pop().unwrap())
|
Ok(embeddings.pop().unwrap())
|
||||||
}
|
}
|
||||||
@ -178,7 +191,7 @@ impl Embedder {
|
|||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads
|
threads
|
||||||
.install(move || {
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
|
||||||
})
|
})
|
||||||
.map_err(|error| EmbedError {
|
.map_err(|error| EmbedError {
|
||||||
kind: EmbedErrorKind::PanicInThreadPool(error),
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
@ -207,7 +220,7 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
||||||
let v = embed(data, ["test"].as_slice(), 1, None)
|
let v = embed(data, ["test"].as_slice(), 1, None, None)
|
||||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(v.first().unwrap().dimension())
|
Ok(v.first().unwrap().dimension())
|
||||||
@ -218,6 +231,7 @@ fn embed<S>(
|
|||||||
inputs: &[S],
|
inputs: &[S],
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
expected_dimension: Option<usize>,
|
expected_dimension: Option<usize>,
|
||||||
|
deadline: Option<Instant>,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
where
|
where
|
||||||
S: Serialize,
|
S: Serialize,
|
||||||
@ -245,8 +259,19 @@ where
|
|||||||
}
|
}
|
||||||
Err(retry) => {
|
Err(retry) => {
|
||||||
tracing::warn!("Failed: {}", retry.error);
|
tracing::warn!("Failed: {}", retry.error);
|
||||||
|
if let Some(deadline) = deadline {
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
if now > deadline {
|
||||||
|
tracing::warn!("Could not embed due to deadline");
|
||||||
|
return Err(retry.into_error());
|
||||||
|
}
|
||||||
|
|
||||||
|
let duration_to_deadline = deadline - now;
|
||||||
|
retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
|
||||||
|
} else {
|
||||||
retry.into_duration(attempt)
|
retry.into_duration(attempt)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||||
|
Loading…
Reference in New Issue
Block a user