Remove unwraps

This commit is contained in:
Louis Dureuil 2024-03-20 13:25:10 +01:00
parent b6b4b6bab7
commit f87747f4d3
No known key found for this signature in database
4 changed files with 40 additions and 13 deletions

View File

@ -242,11 +242,9 @@ fn send_original_documents_data(
let request_threads = rayon::ThreadPoolBuilder::new() let request_threads = rayon::ThreadPoolBuilder::new()
.num_threads(crate::vector::REQUEST_PARALLELISM) .num_threads(crate::vector::REQUEST_PARALLELISM)
.thread_name(|index| format!("embedding-request-{index}")) .thread_name(|index| format!("embedding-request-{index}"))
.build() .build()?;
.unwrap();
rayon::spawn(move || { rayon::spawn(move || {
/// FIXME: unwrap
for (name, (embedder, prompt)) in embedders { for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points( let result = extract_vector_points(
documents_chunk_cloned.clone(), documents_chunk_cloned.clone(),

View File

@ -52,8 +52,6 @@ pub enum EmbedErrorKind {
ModelForward(candle_core::Error), ModelForward(candle_core::Error),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String), ManualEmbed(String),
#[error("could not initialize asynchronous runtime: {0}")]
OpenAiRuntimeInit(std::io::Error),
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")] #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
OllamaModelNotFoundError(Option<String>), OllamaModelNotFoundError(Option<String>),
#[error("error deserialization the response body as JSON: {0}")] #[error("error deserialization the response body as JSON: {0}")]
@ -76,6 +74,10 @@ pub enum EmbedErrorKind {
RestOtherStatusCode(u16, Option<String>), RestOtherStatusCode(u16, Option<String>),
#[error("could not reach embedding server: {0}")] #[error("could not reach embedding server: {0}")]
RestNetwork(ureq::Transport), RestNetwork(ureq::Transport),
#[error("was expected '{}' to be an object in query '{0}'", .1.join("."))]
RestNotAnObject(serde_json::Value, Vec<String>),
#[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")]
OpenAiUnexpectedDimension(usize, usize),
} }
impl EmbedError { impl EmbedError {
@ -174,6 +176,20 @@ impl EmbedError {
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError { pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime } Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
} }
pub(crate) fn rest_not_an_object(
query: serde_json::Value,
input_path: Vec<String>,
) -> EmbedError {
Self { kind: EmbedErrorKind::RestNotAnObject(query, input_path), fault: FaultSource::User }
}
pub(crate) fn openai_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
Self {
kind: EmbedErrorKind::OpenAiUnexpectedDimension(expected, got),
fault: FaultSource::Runtime,
}
}
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]

View File

@ -210,16 +210,19 @@ impl Embedder {
while tokens.len() > max_token_count { while tokens.len() > max_token_count {
let window = &tokens[..max_token_count]; let window = &tokens[..max_token_count];
let embedding = self.rest_embedder.embed_tokens(window)?; let embedding = self.rest_embedder.embed_tokens(window)?;
/// FIXME: unwrap embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
embeddings_for_prompt.append(embedding.into_inner()).unwrap(); EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
})?;
tokens = &tokens[max_token_count - OVERLAP_SIZE..]; tokens = &tokens[max_token_count - OVERLAP_SIZE..];
} }
// end of text // end of text
let embedding = self.rest_embedder.embed_tokens(tokens)?; let embedding = self.rest_embedder.embed_tokens(tokens)?;
/// FIXME: unwrap
embeddings_for_prompt.append(embedding.into_inner()).unwrap(); embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
})?;
all_embeddings.push(embeddings_for_prompt); all_embeddings.push(embeddings_for_prompt);
} }

View File

@ -189,19 +189,29 @@ where
[input] => { [input] => {
let mut body = options.query.clone(); let mut body = options.query.clone();
/// FIXME unwrap body.as_object_mut()
body.as_object_mut().unwrap().insert(input.clone(), input_value); .ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.insert(input.clone(), input_value);
body body
} }
[path @ .., input] => { [path @ .., input] => {
let mut body = options.query.clone(); let mut body = options.query.clone();
/// FIXME unwrap
let mut current_value = &mut body; let mut current_value = &mut body;
for component in path { for component in path {
current_value = current_value current_value = current_value
.as_object_mut() .as_object_mut()
.unwrap() .ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.entry(component.clone()) .entry(component.clone())
.or_insert(serde_json::json!({})); .or_insert(serde_json::json!({}));
} }