4382: Bring back changes from `release-v1.6.1` into `main` r=curquiza a=dureuill

Bring back changes from release-v1.6.1 into main

Supersedes https://github.com/meilisearch/meilisearch/pull/4380 and #4381 

Third time's the charm

Co-authored-by: curquiza <curquiza@users.noreply.github.com>
Co-authored-by: Louis Dureuil <louis@meilisearch.com>
Co-authored-by: Tamo <tamo@meilisearch.com>
Co-authored-by: Morgane Dubus <30866152+mdubus@users.noreply.github.com>
This commit is contained in:
meili-bors[bot] 2024-02-01 11:16:31 +00:00 committed by GitHub
commit ff76d8f21a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 332 additions and 72 deletions

32
Cargo.lock generated
View File

@ -494,7 +494,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]] [[package]]
name = "benchmarks" name = "benchmarks"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",
@ -1476,7 +1476,7 @@ dependencies = [
[[package]] [[package]]
name = "dump" name = "dump"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"big_s", "big_s",
@ -1720,7 +1720,7 @@ dependencies = [
[[package]] [[package]]
name = "file-store" name = "file-store"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"faux", "faux",
"tempfile", "tempfile",
@ -1742,7 +1742,7 @@ dependencies = [
[[package]] [[package]]
name = "filter-parser" name = "filter-parser"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"insta", "insta",
"nom", "nom",
@ -1773,7 +1773,7 @@ dependencies = [
[[package]] [[package]]
name = "flatten-serde-json" name = "flatten-serde-json"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"criterion", "criterion",
"serde_json", "serde_json",
@ -1891,7 +1891,7 @@ dependencies = [
[[package]] [[package]]
name = "fuzzers" name = "fuzzers"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"arbitrary", "arbitrary",
"clap", "clap",
@ -2856,7 +2856,7 @@ checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d"
[[package]] [[package]]
name = "index-scheduler" name = "index-scheduler"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"big_s", "big_s",
@ -3043,7 +3043,7 @@ dependencies = [
[[package]] [[package]]
name = "json-depth-checker" name = "json-depth-checker"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"criterion", "criterion",
"serde_json", "serde_json",
@ -3555,7 +3555,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]] [[package]]
name = "meili-snap" name = "meili-snap"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"insta", "insta",
"md5", "md5",
@ -3564,7 +3564,7 @@ dependencies = [
[[package]] [[package]]
name = "meilisearch" name = "meilisearch"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"actix-cors", "actix-cors",
"actix-http", "actix-http",
@ -3655,7 +3655,7 @@ dependencies = [
[[package]] [[package]]
name = "meilisearch-auth" name = "meilisearch-auth"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"enum-iterator", "enum-iterator",
@ -3674,7 +3674,7 @@ dependencies = [
[[package]] [[package]]
name = "meilisearch-types" name = "meilisearch-types"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"actix-web", "actix-web",
"anyhow", "anyhow",
@ -3704,7 +3704,7 @@ dependencies = [
[[package]] [[package]]
name = "meilitool" name = "meilitool"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clap", "clap",
@ -3743,7 +3743,7 @@ dependencies = [
[[package]] [[package]]
name = "milli" name = "milli"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"arroy", "arroy",
"big_s", "big_s",
@ -4141,7 +4141,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]] [[package]]
name = "permissive-json-pointer" name = "permissive-json-pointer"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"big_s", "big_s",
"serde_json", "serde_json",
@ -6232,7 +6232,7 @@ dependencies = [
[[package]] [[package]]
name = "xtask" name = "xtask"
version = "1.6.0" version = "1.6.1"
dependencies = [ dependencies = [
"cargo_metadata", "cargo_metadata",
"clap", "clap",

View File

@ -20,7 +20,7 @@ members = [
] ]
[workspace.package] [workspace.package]
version = "1.6.0" version = "1.6.1"
authors = ["Quentin de Quelen <quentin@dequelen.me>", "Clément Renault <clement@meilisearch.com>"] authors = ["Quentin de Quelen <quentin@dequelen.me>", "Clément Renault <clement@meilisearch.com>"]
description = "Meilisearch HTTP server" description = "Meilisearch HTTP server"
homepage = "https://meilisearch.com" homepage = "https://meilisearch.com"

View File

@ -154,5 +154,5 @@ greek = ["meilisearch-types/greek"]
khmer = ["meilisearch-types/khmer"] khmer = ["meilisearch-types/khmer"]
[package.metadata.mini-dashboard] [package.metadata.mini-dashboard]
assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.12/build.zip" assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.13/build.zip"
sha1 = "acfe9a018c93eb0604ea87ee87bff7df5474e18e" sha1 = "e20cc9b390003c6c844f4b8bcc5c5013191a77ff"

View File

@ -64,7 +64,7 @@ impl Display for Value {
write!( write!(
f, f,
"{}", "{}",
json_string!(self, { ".enqueuedAt" => "[date]", ".processedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" }) json_string!(self, { ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" })
) )
} }
} }

View File

@ -1760,6 +1760,181 @@ async fn add_documents_invalid_geo_field() {
"finishedAt": "[date]" "finishedAt": "[date]"
} }
"###); "###);
// The three next tests are related to #4333
// _geo has a lat and lng but set to `null`
let documents = json!([
{
"id": "12",
"_geo": { "lng": null, "lat": 67}
}
]);
let (response, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
let response = index.wait_task(response.uid()).await;
snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
@r###"
{
"uid": 14,
"indexUid": "test",
"status": "failed",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 0
},
"error": {
"message": "Could not parse longitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.",
"code": "invalid_document_geo_field",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
},
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
// _geo has a lat and lng but set to `null`
let documents = json!([
{
"id": "12",
"_geo": { "lng": 35, "lat": null }
}
]);
let (response, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
let response = index.wait_task(response.uid()).await;
snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
@r###"
{
"uid": 15,
"indexUid": "test",
"status": "failed",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 0
},
"error": {
"message": "Could not parse latitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.",
"code": "invalid_document_geo_field",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
},
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
// _geo has a lat and lng but set to `null`
let documents = json!([
{
"id": "13",
"_geo": { "lng": null, "lat": null }
}
]);
let (response, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
let response = index.wait_task(response.uid()).await;
snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
@r###"
{
"uid": 16,
"indexUid": "test",
"status": "failed",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 0
},
"error": {
"message": "Could not parse latitude nor longitude in the document with the id: `13`. Was expecting finite numbers but instead got `null` and `null`.",
"code": "invalid_document_geo_field",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
},
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
}
// Related to #4333
#[actix_rt::test]
async fn add_invalid_geo_and_then_settings() {
let server = Server::new().await;
let index = server.index("test");
index.create(Some("id")).await;
// _geo is not an object
let documents = json!([
{
"id": "11",
"_geo": { "lat": null, "lng": null },
}
]);
let (ret, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
let ret = index.wait_task(ret.uid()).await;
snapshot!(ret, @r###"
{
"uid": 1,
"indexUid": "test",
"status": "succeeded",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 1
},
"error": null,
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
let (ret, code) = index.update_settings(json!({"sortableAttributes": ["_geo"]})).await;
snapshot!(code, @"202 Accepted");
let ret = index.wait_task(ret.uid()).await;
snapshot!(ret, @r###"
{
"uid": 2,
"indexUid": "test",
"status": "failed",
"type": "settingsUpdate",
"canceledBy": null,
"details": {
"sortableAttributes": [
"_geo"
]
},
"error": {
"message": "Could not parse latitude in the document with the id: `\"11\"`. Was expecting a finite number but instead got `null`.",
"code": "invalid_document_geo_field",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
},
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
} }
#[actix_rt::test] #[actix_rt::test]

View File

@ -87,6 +87,52 @@ async fn simple_search() {
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
} }
#[actix_rt::test]
async fn highlighter() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let (response, code) = index
.search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
"hybrid": {"semanticRatio": 0.2},
"attributesToHighlight": [
"desc"
],
"highlightPreTag": "**BEGIN**",
"highlightPostTag": "**END**"
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}}}]"###);
let (response, code) = index
.search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
"hybrid": {"semanticRatio": 0.8},
"attributesToHighlight": [
"desc"
],
"highlightPreTag": "**BEGIN**",
"highlightPostTag": "**END**"
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_semanticScore":0.9472136}]"###);
// no highlighting on full semantic
let (response, code) = index
.search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
"hybrid": {"semanticRatio": 1.0},
"attributesToHighlight": [
"desc"
],
"highlightPreTag": "**BEGIN**",
"highlightPostTag": "**END**"
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}}]"###);
}
#[actix_rt::test] #[actix_rt::test]
async fn invalid_semantic_ratio() { async fn invalid_semantic_ratio() {
let server = Server::new().await; let server = Server::new().await;

View File

@ -102,7 +102,7 @@ impl ScoreWithRatioResult {
} }
SearchResult { SearchResult {
matching_words: left.matching_words, matching_words: right.matching_words,
candidates: left.candidates | right.candidates, candidates: left.candidates | right.candidates,
documents_ids, documents_ids,
document_scores, document_scores,

View File

@ -34,7 +34,9 @@ pub fn extract_geo_points<R: io::Read + io::Seek>(
// since we only need the primary key when we throw an error // since we only need the primary key when we throw an error
// we create this getter to lazily get it when needed // we create this getter to lazily get it when needed
let document_id = || -> Value { let document_id = || -> Value {
let document_id = obkv.get(primary_key_id).unwrap(); let reader = KvReaderDelAdd::new(obkv.get(primary_key_id).unwrap());
let document_id =
reader.get(DelAdd::Deletion).or(reader.get(DelAdd::Addition)).unwrap();
serde_json::from_slice(document_id).unwrap() serde_json::from_slice(document_id).unwrap()
}; };

View File

@ -339,9 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
indexer: GrenadParameters, indexer: GrenadParameters,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
) -> Result<grenad::Reader<BufReader<File>>> { ) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
// docid, state with embedding // docid, state with embedding
@ -375,11 +373,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
current_chunk_ids.push(docid); current_chunk_ids.push(docid);
if chunks.len() == chunks.capacity() { if chunks.len() == chunks.capacity() {
let chunked_embeds = rt let chunked_embeds = embedder
.block_on( .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
)
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
@ -396,8 +391,8 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
// send last chunk // send last chunk
if !chunks.is_empty() { if !chunks.is_empty() {
let chunked_embeds = rt let chunked_embeds = embedder
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) .embed_chunks(std::mem::take(&mut chunks))
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids for (docid, embeddings) in chunks_ids
@ -410,15 +405,17 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
} }
if !current_chunk.is_empty() { if !current_chunk.is_empty() {
let embeds = rt let embeds = embedder
.block_on(embedder.embed(std::mem::take(&mut current_chunk))) .embed_chunks(vec![std::mem::take(&mut current_chunk)])
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
if let Some(embeds) = embeds.first() {
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
} }
} }
}
writer_into_reader(state_writer) writer_into_reader(state_writer)
} }

View File

@ -67,6 +67,10 @@ pub enum EmbedErrorKind {
OpenAiUnhandledStatusCode(u16), OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String), ManualEmbed(String),
#[error("could not initialize asynchronous runtime: {0}")]
OpenAiRuntimeInit(std::io::Error),
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
} }
impl EmbedError { impl EmbedError {
@ -117,6 +121,14 @@ impl EmbedError {
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
} }
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -183,10 +195,6 @@ impl NewEmbedderError {
} }
} }
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
} }
@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind {
#[error("loading model failed: {0}")] #[error("loading model failed: {0}")]
LoadModel(candle_core::Error), LoadModel(candle_core::Error),
// openai // openai
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
#[error("The API key passed to Authorization error was in an invalid format: {0}")] #[error("The API key passed to Authorization error was in an invalid format: {0}")]
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
} }

View File

@ -151,7 +151,8 @@ impl Embedder {
let token_ids = tokens let token_ids = tokens
.iter() .iter()
.map(|tokens| { .map(|tokens| {
let tokens = tokens.get_ids().to_vec(); let mut tokens = tokens.get_ids().to_vec();
tokens.truncate(512);
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
}) })
.collect::<Result<Vec<_>, EmbedError>>()?; .collect::<Result<Vec<_>, EmbedError>>()?;

View File

@ -163,18 +163,24 @@ impl Embedder {
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts).await, Embedder::OpenAi(embedder) => {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::UserProvided(embedder) => embedder.embed(texts), Embedder::UserProvided(embedder) => embedder.embed(texts),
} }
} }
pub async fn embed_chunks( /// # Panics
///
/// - if called from an asynchronous context
pub fn embed_chunks(
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
} }
} }

View File

@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};
#[derive(Debug)] #[derive(Debug)]
pub struct Embedder { pub struct Embedder {
client: reqwest::Client, headers: reqwest::header::HeaderMap,
tokenizer: tiktoken_rs::CoreBPE, tokenizer: tiktoken_rs::CoreBPE,
options: EmbedderOptions, options: EmbedderOptions,
} }
@ -95,6 +95,13 @@ impl EmbedderOptions {
} }
impl Embedder { impl Embedder {
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
reqwest::ClientBuilder::new()
.default_headers(self.headers.clone())
.build()
.map_err(EmbedError::openai_initialize_web_client)
}
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new(); let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default(); let mut inferred_api_key = Default::default();
@ -111,25 +118,25 @@ impl Embedder {
reqwest::header::CONTENT_TYPE, reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_static("application/json"), reqwest::header::HeaderValue::from_static("application/json"),
); );
let client = reqwest::ClientBuilder::new()
.default_headers(headers)
.build()
.map_err(NewEmbedderError::openai_initialize_web_client)?;
// looking at the code it is very unclear that this can actually fail. // looking at the code it is very unclear that this can actually fail.
let tokenizer = tiktoken_rs::cl100k_base().unwrap(); let tokenizer = tiktoken_rs::cl100k_base().unwrap();
Ok(Self { options, client, tokenizer }) Ok(Self { options, headers, tokenizer })
} }
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub async fn embed(
&self,
texts: Vec<String>,
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut tokenized = false; let mut tokenized = false;
for attempt in 0..7 { for attempt in 0..7 {
let result = if tokenized { let result = if tokenized {
self.try_embed_tokenized(&texts).await self.try_embed_tokenized(&texts, client).await
} else { } else {
self.try_embed(&texts).await self.try_embed(&texts, client).await
}; };
let retry_duration = match result { let retry_duration = match result {
@ -145,9 +152,9 @@ impl Embedder {
} }
let result = if tokenized { let result = if tokenized {
self.try_embed_tokenized(&texts).await self.try_embed_tokenized(&texts, client).await
} else { } else {
self.try_embed(&texts).await self.try_embed(&texts, client).await
}; };
result.map_err(Retry::into_error) result.map_err(Retry::into_error)
@ -225,13 +232,13 @@ impl Embedder {
async fn try_embed<S: AsRef<str> + serde::Serialize>( async fn try_embed<S: AsRef<str> + serde::Serialize>(
&self, &self,
texts: &[S], texts: &[S],
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, Retry> { ) -> Result<Vec<Embeddings<f32>>, Retry> {
for text in texts { for text in texts {
log::trace!("Received prompt: {}", text.as_ref()) log::trace!("Received prompt: {}", text.as_ref())
} }
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
let response = self let response = client
.client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
.json(&request) .json(&request)
.send() .send()
@ -256,7 +263,11 @@ impl Embedder {
.collect()) .collect())
} }
async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> { async fn try_embed_tokenized(
&self,
text: &[String],
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, Retry> {
pub const OVERLAP_SIZE: usize = 200; pub const OVERLAP_SIZE: usize = 200;
let mut all_embeddings = Vec::with_capacity(text.len()); let mut all_embeddings = Vec::with_capacity(text.len());
for text in text { for text in text {
@ -264,7 +275,7 @@ impl Embedder {
let encoded = self.tokenizer.encode_ordinary(text.as_str()); let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len(); let len = encoded.len();
if len < max_token_count { if len < max_token_count {
all_embeddings.append(&mut self.try_embed(&[text]).await?); all_embeddings.append(&mut self.try_embed(&[text], client).await?);
continue; continue;
} }
@ -273,22 +284,26 @@ impl Embedder {
Embeddings::new(self.options.embedding_model.dimensions()); Embeddings::new(self.options.embedding_model.dimensions());
while tokens.len() > max_token_count { while tokens.len() > max_token_count {
let window = &tokens[..max_token_count]; let window = &tokens[..max_token_count];
embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap(); embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
tokens = &tokens[max_token_count - OVERLAP_SIZE..]; tokens = &tokens[max_token_count - OVERLAP_SIZE..];
} }
// end of text // end of text
embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap(); embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
all_embeddings.push(embeddings_for_prompt); all_embeddings.push(embeddings_for_prompt);
} }
Ok(all_embeddings) Ok(all_embeddings)
} }
async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> { async fn embed_tokens(
&self,
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
for attempt in 0..9 { for attempt in 0..9 {
let duration = match self.try_embed_tokens(tokens).await { let duration = match self.try_embed_tokens(tokens, client).await {
Ok(embedding) => return Ok(embedding), Ok(embedding) => return Ok(embedding),
Err(retry) => retry.into_duration(attempt), Err(retry) => retry.into_duration(attempt),
} }
@ -297,14 +312,19 @@ impl Embedder {
tokio::time::sleep(duration).await; tokio::time::sleep(duration).await;
} }
self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) self.try_embed_tokens(tokens, client)
.await
.map_err(|retry| Retry::give_up(retry.into_error()))
} }
async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> { async fn try_embed_tokens(
&self,
tokens: &[usize],
client: &reqwest::Client,
) -> Result<Embedding, Retry> {
let request = let request =
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
let response = self let response = client
.client
.post(OPENAI_EMBEDDINGS_URL) .post(OPENAI_EMBEDDINGS_URL)
.json(&request) .json(&request)
.send() .send()
@ -322,12 +342,19 @@ impl Embedder {
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
} }
pub async fn embed_chunks( pub fn embed_chunks(
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) let rt = tokio::runtime::Builder::new_current_thread()
.await .enable_io()
.enable_time()
.build()
.map_err(EmbedError::openai_runtime_init)?;
let client = self.new_client()?;
rt.block_on(futures::future::try_join_all(
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
))
} }
pub fn chunk_count_hint(&self) -> usize { pub fn chunk_count_hint(&self) -> usize {