mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-11 07:28:56 +01:00
Update ollama and openai impls to use the rest embedder internally
This commit is contained in:
parent
8708cbef25
commit
ac52c857e8
@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
prompt_reader: grenad::Reader<R>,
|
prompt_reader: grenad::Reader<R>,
|
||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
embedder: Arc<Embedder>,
|
embedder: Arc<Embedder>,
|
||||||
|
request_threads: &rayon::ThreadPool,
|
||||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||||
puffin::profile_function!();
|
puffin::profile_function!();
|
||||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
||||||
@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
|
|
||||||
if chunks.len() == chunks.capacity() {
|
if chunks.len() == chunks.capacity() {
|
||||||
let chunked_embeds = embedder
|
let chunked_embeds = embedder
|
||||||
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
|
.embed_chunks(
|
||||||
|
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
|
||||||
|
request_threads,
|
||||||
|
)
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
|
|
||||||
@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
// send last chunk
|
// send last chunk
|
||||||
if !chunks.is_empty() {
|
if !chunks.is_empty() {
|
||||||
let chunked_embeds = embedder
|
let chunked_embeds = embedder
|
||||||
.embed_chunks(std::mem::take(&mut chunks))
|
.embed_chunks(std::mem::take(&mut chunks), request_threads)
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
for (docid, embeddings) in chunks_ids
|
for (docid, embeddings) in chunks_ids
|
||||||
@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
|
|
||||||
if !current_chunk.is_empty() {
|
if !current_chunk.is_empty() {
|
||||||
let embeds = embedder
|
let embeds = embedder
|
||||||
.embed_chunks(vec![std::mem::take(&mut current_chunk)])
|
.embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
|
||||||
.map_err(crate::vector::Error::from)
|
.map_err(crate::vector::Error::from)
|
||||||
.map_err(crate::Error::from)?;
|
.map_err(crate::Error::from)?;
|
||||||
|
|
||||||
|
@ -238,7 +238,15 @@ fn send_original_documents_data(
|
|||||||
|
|
||||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||||
|
|
||||||
|
let request_threads = rayon::ThreadPoolBuilder::new()
|
||||||
|
.num_threads(crate::vector::REQUEST_PARALLELISM)
|
||||||
|
.thread_name(|index| format!("embedding-request-{index}"))
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
rayon::spawn(move || {
|
rayon::spawn(move || {
|
||||||
|
/// FIXME: unwrap
|
||||||
for (name, (embedder, prompt)) in embedders {
|
for (name, (embedder, prompt)) in embedders {
|
||||||
let result = extract_vector_points(
|
let result = extract_vector_points(
|
||||||
documents_chunk_cloned.clone(),
|
documents_chunk_cloned.clone(),
|
||||||
@ -249,7 +257,12 @@ fn send_original_documents_data(
|
|||||||
);
|
);
|
||||||
match result {
|
match result {
|
||||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||||
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
|
let embeddings = match extract_embeddings(
|
||||||
|
prompts,
|
||||||
|
indexer,
|
||||||
|
embedder.clone(),
|
||||||
|
&request_threads,
|
||||||
|
) {
|
||||||
Ok(results) => Some(results),
|
Ok(results) => Some(results),
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||||
|
@ -2,9 +2,7 @@ use std::path::PathBuf;
|
|||||||
|
|
||||||
use hf_hub::api::sync::ApiError;
|
use hf_hub::api::sync::ApiError;
|
||||||
|
|
||||||
use super::ollama::OllamaError;
|
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::vector::openai::OpenAiError;
|
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("Error while generating embeddings: {inner}")]
|
#[error("Error while generating embeddings: {inner}")]
|
||||||
@ -52,43 +50,12 @@ pub enum EmbedErrorKind {
|
|||||||
TensorValue(candle_core::Error),
|
TensorValue(candle_core::Error),
|
||||||
#[error("could not run model: {0}")]
|
#[error("could not run model: {0}")]
|
||||||
ModelForward(candle_core::Error),
|
ModelForward(candle_core::Error),
|
||||||
#[error("could not reach OpenAI: {0}")]
|
|
||||||
OpenAiNetwork(ureq::Transport),
|
|
||||||
#[error("unexpected response from OpenAI: {0}")]
|
|
||||||
OpenAiUnexpected(ureq::Error),
|
|
||||||
#[error("could not authenticate against OpenAI: {0:?}")]
|
|
||||||
OpenAiAuth(Option<OpenAiError>),
|
|
||||||
#[error("sent too many requests to OpenAI: {0:?}")]
|
|
||||||
OpenAiTooManyRequests(Option<OpenAiError>),
|
|
||||||
#[error("received internal error from OpenAI: {0:?}")]
|
|
||||||
OpenAiInternalServerError(Option<OpenAiError>),
|
|
||||||
#[error("sent too many tokens in a request to OpenAI: {0:?}")]
|
|
||||||
OpenAiTooManyTokens(Option<OpenAiError>),
|
|
||||||
#[error("received unhandled HTTP status code {0} from OpenAI")]
|
|
||||||
OpenAiUnhandledStatusCode(u16),
|
|
||||||
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
||||||
ManualEmbed(String),
|
ManualEmbed(String),
|
||||||
#[error("could not initialize asynchronous runtime: {0}")]
|
#[error("could not initialize asynchronous runtime: {0}")]
|
||||||
OpenAiRuntimeInit(std::io::Error),
|
OpenAiRuntimeInit(std::io::Error),
|
||||||
#[error("initializing web client for sending embedding requests failed: {0}")]
|
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
|
||||||
InitWebClient(reqwest::Error),
|
OllamaModelNotFoundError(Option<String>),
|
||||||
// Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends.
|
|
||||||
#[error("unexpected response from Ollama: {0}")]
|
|
||||||
OllamaUnexpected(reqwest::Error),
|
|
||||||
#[error("sent too many requests to Ollama: {0}")]
|
|
||||||
OllamaTooManyRequests(OllamaError),
|
|
||||||
#[error("received internal error from Ollama: {0}")]
|
|
||||||
OllamaInternalServerError(OllamaError),
|
|
||||||
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")]
|
|
||||||
OllamaModelNotFoundError(OllamaError),
|
|
||||||
#[error("received unhandled HTTP status code {0} from Ollama")]
|
|
||||||
OllamaUnhandledStatusCode(u16),
|
|
||||||
#[error("error serializing template context: {0}")]
|
|
||||||
RestTemplateContextSerialization(liquid::Error),
|
|
||||||
#[error(
|
|
||||||
"error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}'"
|
|
||||||
)]
|
|
||||||
RestTemplateError(liquid::Error),
|
|
||||||
#[error("error deserialization the response body as JSON: {0}")]
|
#[error("error deserialization the response body as JSON: {0}")]
|
||||||
RestResponseDeserialization(std::io::Error),
|
RestResponseDeserialization(std::io::Error),
|
||||||
#[error("component `{0}` not found in path `{1}` in response: `{2}`")]
|
#[error("component `{0}` not found in path `{1}` in response: `{2}`")]
|
||||||
@ -128,77 +95,14 @@ impl EmbedError {
|
|||||||
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn openai_network(inner: ureq::Transport) -> Self {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn openai_unexpected(inner: ureq::Error) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
||||||
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
|
pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
|
||||||
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
|
||||||
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
|
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError {
|
|
||||||
Self {
|
|
||||||
kind: EmbedErrorKind::RestTemplateContextSerialization(error),
|
|
||||||
fault: FaultSource::Bug,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError {
|
|
||||||
Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
|
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
|
||||||
Self {
|
Self {
|
||||||
kind: EmbedErrorKind::RestResponseDeserialization(error),
|
kind: EmbedErrorKind::RestResponseDeserialization(error),
|
||||||
@ -335,17 +239,6 @@ impl NewEmbedderError {
|
|||||||
fault: FaultSource::Runtime,
|
fault: FaultSource::Runtime,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
|
||||||
Self {
|
|
||||||
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
|
|
||||||
fault: FaultSource::User,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
|
||||||
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
@ -392,7 +285,4 @@ pub enum NewEmbedderErrorKind {
|
|||||||
CouldNotDetermineDimension(EmbedError),
|
CouldNotDetermineDimension(EmbedError),
|
||||||
#[error("loading model failed: {0}")]
|
#[error("loading model failed: {0}")]
|
||||||
LoadModel(candle_core::Error),
|
LoadModel(candle_core::Error),
|
||||||
// openai
|
|
||||||
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
|
|
||||||
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
|
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,8 @@ pub use self::error::Error;
|
|||||||
|
|
||||||
pub type Embedding = Vec<f32>;
|
pub type Embedding = Vec<f32>;
|
||||||
|
|
||||||
|
pub const REQUEST_PARALLELISM: usize = 40;
|
||||||
|
|
||||||
/// One or multiple embeddings stored consecutively in a flat vector.
|
/// One or multiple embeddings stored consecutively in a flat vector.
|
||||||
pub struct Embeddings<F> {
|
pub struct Embeddings<F> {
|
||||||
data: Vec<F>,
|
data: Vec<F>,
|
||||||
@ -99,7 +101,7 @@ pub enum Embedder {
|
|||||||
/// An embedder based on running local models, fetched from the Hugging Face Hub.
|
/// An embedder based on running local models, fetched from the Hugging Face Hub.
|
||||||
HuggingFace(hf::Embedder),
|
HuggingFace(hf::Embedder),
|
||||||
/// An embedder based on making embedding queries against the OpenAI API.
|
/// An embedder based on making embedding queries against the OpenAI API.
|
||||||
OpenAi(openai::sync::Embedder),
|
OpenAi(openai::Embedder),
|
||||||
/// An embedder based on the user providing the embeddings in the documents and queries.
|
/// An embedder based on the user providing the embeddings in the documents and queries.
|
||||||
UserProvided(manual::Embedder),
|
UserProvided(manual::Embedder),
|
||||||
Ollama(ollama::Embedder),
|
Ollama(ollama::Embedder),
|
||||||
@ -202,7 +204,7 @@ impl Embedder {
|
|||||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||||
Ok(match options {
|
Ok(match options {
|
||||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?),
|
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||||
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
|
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
|
||||||
EmbedderOptions::UserProvided(options) => {
|
EmbedderOptions::UserProvided(options) => {
|
||||||
Self::UserProvided(manual::Embedder::new(options))
|
Self::UserProvided(manual::Embedder::new(options))
|
||||||
@ -213,17 +215,14 @@ impl Embedder {
|
|||||||
/// Embed one or multiple texts.
|
/// Embed one or multiple texts.
|
||||||
///
|
///
|
||||||
/// Each text can be embedded as one or multiple embeddings.
|
/// Each text can be embedded as one or multiple embeddings.
|
||||||
pub async fn embed(
|
pub fn embed(
|
||||||
&self,
|
&self,
|
||||||
texts: Vec<String>,
|
texts: Vec<String>,
|
||||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
||||||
Embedder::Ollama(embedder) => {
|
Embedder::Ollama(embedder) => embedder.embed(texts),
|
||||||
let client = embedder.new_client()?;
|
|
||||||
embedder.embed(texts, &client).await
|
|
||||||
}
|
|
||||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -231,18 +230,15 @@ impl Embedder {
|
|||||||
/// Embed multiple chunks of texts.
|
/// Embed multiple chunks of texts.
|
||||||
///
|
///
|
||||||
/// Each chunk is composed of one or multiple texts.
|
/// Each chunk is composed of one or multiple texts.
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
///
|
|
||||||
/// - if called from an asynchronous context
|
|
||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
|
threads: &rayon::ThreadPool,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||||
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||||
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,293 +1,94 @@
|
|||||||
// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
// The common components of the Ollama and OpenAI interfaces might need to be extracted.
|
|
||||||
|
|
||||||
use std::fmt::Display;
|
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||||
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use reqwest::StatusCode;
|
use super::{DistributionShift, Embeddings};
|
||||||
|
|
||||||
use super::error::{EmbedError, NewEmbedderError};
|
|
||||||
use super::openai::Retry;
|
|
||||||
use super::{DistributionShift, Embedding, Embeddings};
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
headers: reqwest::header::HeaderMap,
|
rest_embedder: RestEmbedder,
|
||||||
options: EmbedderOptions,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub embedding_model: EmbeddingModel,
|
pub embedding_model: String,
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(
|
|
||||||
Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
|
|
||||||
)]
|
|
||||||
#[deserr(deny_unknown_fields)]
|
|
||||||
pub struct EmbeddingModel {
|
|
||||||
name: String,
|
|
||||||
dimensions: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
|
||||||
struct OllamaRequest<'a> {
|
|
||||||
model: &'a str,
|
|
||||||
prompt: &'a str,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
|
||||||
struct OllamaResponse {
|
|
||||||
embedding: Embedding,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize)]
|
|
||||||
pub struct OllamaError {
|
|
||||||
error: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EmbeddingModel {
|
|
||||||
pub fn max_token(&self) -> usize {
|
|
||||||
// this might not be the same for all models
|
|
||||||
8192
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn default_dimensions(&self) -> usize {
|
|
||||||
// Dimensions for nomic-embed-text
|
|
||||||
768
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn name(&self) -> String {
|
|
||||||
self.name.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_name(name: &str) -> Self {
|
|
||||||
Self { name: name.to_string(), dimensions: 0 }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn supports_overriding_dimensions(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for EmbeddingModel {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self { name: "nomic-embed-text".to_string(), dimensions: 0 }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
pub fn with_default_model() -> Self {
|
pub fn with_default_model() -> Self {
|
||||||
Self { embedding_model: Default::default() }
|
Self { embedding_model: "nomic-embed-text".into() }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self {
|
pub fn with_embedding_model(embedding_model: String) -> Self {
|
||||||
Self { embedding_model }
|
Self { embedding_model }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
|
|
||||||
reqwest::ClientBuilder::new()
|
|
||||||
.default_headers(self.headers.clone())
|
|
||||||
.build()
|
|
||||||
.map_err(EmbedError::openai_initialize_web_client)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
let model = options.embedding_model.as_str();
|
||||||
headers.insert(
|
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
|
||||||
reqwest::header::CONTENT_TYPE,
|
api_key: None,
|
||||||
reqwest::header::HeaderValue::from_static("application/json"),
|
distribution: None,
|
||||||
);
|
dimensions: None,
|
||||||
|
url: get_ollama_path(),
|
||||||
let mut embedder = Self { options, headers };
|
query: serde_json::json!({
|
||||||
|
"model": model,
|
||||||
let rt = tokio::runtime::Builder::new_current_thread()
|
}),
|
||||||
.enable_io()
|
input_field: vec!["prompt".to_owned()],
|
||||||
.enable_time()
|
path_to_embeddings: Default::default(),
|
||||||
.build()
|
embedding_object: vec!["embedding".to_owned()],
|
||||||
.map_err(EmbedError::openai_runtime_init)
|
input_type: super::rest::InputType::Text,
|
||||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
}) {
|
||||||
|
Ok(embedder) => embedder,
|
||||||
// Get dimensions from Ollama
|
Err(NewEmbedderError {
|
||||||
let request =
|
kind:
|
||||||
OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" };
|
NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
|
||||||
// TODO: Refactor into shared error type
|
kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
|
||||||
let client = embedder
|
fault: _,
|
||||||
.new_client()
|
}),
|
||||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
fault: _,
|
||||||
|
}) => {
|
||||||
rt.block_on(async move {
|
return Err(NewEmbedderError::could_not_determine_dimension(
|
||||||
let response = client
|
EmbedError::ollama_model_not_found(error),
|
||||||
.post(get_ollama_path())
|
))
|
||||||
.json(&request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::ollama_unexpected)
|
|
||||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
|
||||||
|
|
||||||
// Process error in case model not found
|
|
||||||
let response = Self::check_response(response).await.map_err(|_err| {
|
|
||||||
let e = EmbedError::ollama_model_not_found(OllamaError {
|
|
||||||
error: format!("model: {}", embedder.options.embedding_model.name()),
|
|
||||||
});
|
|
||||||
NewEmbedderError::ollama_could_not_determine_dimension(e)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let response: OllamaResponse = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::ollama_unexpected)
|
|
||||||
.map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
|
|
||||||
|
|
||||||
let embedding = Embeddings::from_single_embedding(response.embedding);
|
|
||||||
|
|
||||||
embedder.options.embedding_model.dimensions = embedding.dimension();
|
|
||||||
|
|
||||||
tracing::info!(
|
|
||||||
"ollama model {} with dimensionality {} added",
|
|
||||||
embedder.options.embedding_model.name(),
|
|
||||||
embedding.dimension()
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(embedder)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
|
|
||||||
if !response.status().is_success() {
|
|
||||||
// Not the same number of possible error cases covered as with OpenAI.
|
|
||||||
match response.status() {
|
|
||||||
StatusCode::TOO_MANY_REQUESTS => {
|
|
||||||
let error_response: OllamaError = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::ollama_unexpected)
|
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
|
|
||||||
return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
|
|
||||||
OllamaError { error: error_response.error },
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE => {
|
|
||||||
let error_response: OllamaError = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::ollama_unexpected)
|
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
|
|
||||||
OllamaError { error: error_response.error },
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
StatusCode::NOT_FOUND => {
|
|
||||||
let error_response: OllamaError = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::ollama_unexpected)
|
|
||||||
.map_err(Retry::give_up)?;
|
|
||||||
|
|
||||||
return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
|
|
||||||
error: error_response.error,
|
|
||||||
})));
|
|
||||||
}
|
|
||||||
code => {
|
|
||||||
return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
|
|
||||||
code.as_u16(),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
Err(error) => return Err(error),
|
||||||
Ok(response)
|
};
|
||||||
|
|
||||||
|
Ok(Self { rest_embedder })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed(
|
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
&self,
|
match self.rest_embedder.embed(texts) {
|
||||||
texts: Vec<String>,
|
Ok(embeddings) => Ok(embeddings),
|
||||||
client: &reqwest::Client,
|
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
Err(EmbedError::ollama_model_not_found(error))
|
||||||
// Ollama only embedds one document at a time.
|
|
||||||
let mut results = Vec::with_capacity(texts.len());
|
|
||||||
|
|
||||||
// The retry loop is inside the texts loop, might have to switch that around
|
|
||||||
for text in texts {
|
|
||||||
// Retries copied from openai.rs
|
|
||||||
for attempt in 0..7 {
|
|
||||||
let retry_duration = match self.try_embed(&text, client).await {
|
|
||||||
Ok(result) => {
|
|
||||||
results.push(result);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(retry) => {
|
|
||||||
tracing::warn!("Failed: {}", retry.error);
|
|
||||||
retry.into_duration(attempt)
|
|
||||||
}
|
|
||||||
}?;
|
|
||||||
tracing::warn!(
|
|
||||||
"Attempt #{}, retrying after {}ms.",
|
|
||||||
attempt,
|
|
||||||
retry_duration.as_millis()
|
|
||||||
);
|
|
||||||
tokio::time::sleep(retry_duration).await;
|
|
||||||
}
|
}
|
||||||
|
Err(error) => Err(error),
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(results)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn try_embed(
|
|
||||||
&self,
|
|
||||||
text: &str,
|
|
||||||
client: &reqwest::Client,
|
|
||||||
) -> Result<Embeddings<f32>, Retry> {
|
|
||||||
let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
|
|
||||||
let response = client
|
|
||||||
.post(get_ollama_path())
|
|
||||||
.json(&request)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::openai_network)
|
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
|
|
||||||
let response = Self::check_response(response).await?;
|
|
||||||
|
|
||||||
let response: OllamaResponse = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(EmbedError::openai_unexpected)
|
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
|
|
||||||
tracing::trace!("response: {:?}", response.embedding);
|
|
||||||
|
|
||||||
let embedding = Embeddings::from_single_embedding(response.embedding);
|
|
||||||
Ok(embedding)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
|
threads: &rayon::ThreadPool,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
let rt = tokio::runtime::Builder::new_current_thread()
|
threads.install(move || {
|
||||||
.enable_io()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
.enable_time()
|
})
|
||||||
.build()
|
|
||||||
.map_err(EmbedError::openai_runtime_init)?;
|
|
||||||
let client = self.new_client()?;
|
|
||||||
rt.block_on(futures::future::try_join_all(
|
|
||||||
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Defaults copied from openai.rs
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
10
|
self.rest_embedder.chunk_count_hint()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
10
|
self.rest_embedder.prompt_count_in_chunk_hint()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dimensions(&self) -> usize {
|
pub fn dimensions(&self) -> usize {
|
||||||
self.options.embedding_model.dimensions
|
self.rest_embedder.dimensions()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
@ -295,12 +96,6 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for OllamaError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{}", self.error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_ollama_path() -> String {
|
fn get_ollama_path() -> String {
|
||||||
// Important: Hostname not enough, has to be entire path to embeddings endpoint
|
// Important: Hostname not enough, has to be entire path to embeddings endpoint
|
||||||
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
|
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use std::fmt::Display;
|
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use super::error::{EmbedError, NewEmbedderError};
|
use super::error::{EmbedError, NewEmbedderError};
|
||||||
use super::{DistributionShift, Embedding, Embeddings};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
|
use super::{DistributionShift, Embeddings};
|
||||||
|
use crate::vector::error::EmbedErrorKind;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
@ -12,6 +12,32 @@ pub struct EmbedderOptions {
|
|||||||
pub dimensions: Option<usize>,
|
pub dimensions: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl EmbedderOptions {
|
||||||
|
pub fn dimensions(&self) -> usize {
|
||||||
|
if self.embedding_model.supports_overriding_dimensions() {
|
||||||
|
self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
|
||||||
|
} else {
|
||||||
|
self.embedding_model.default_dimensions()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn query(&self) -> serde_json::Value {
|
||||||
|
let model = self.embedding_model.name();
|
||||||
|
|
||||||
|
let mut query = serde_json::json!({
|
||||||
|
"model": model,
|
||||||
|
});
|
||||||
|
|
||||||
|
if self.embedding_model.supports_overriding_dimensions() {
|
||||||
|
if let Some(dimensions) = self.dimensions {
|
||||||
|
query["dimensions"] = dimensions.into();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
Debug,
|
Debug,
|
||||||
Clone,
|
Clone,
|
||||||
@ -117,364 +143,112 @@ impl EmbedderOptions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrying in case of failure
|
|
||||||
|
|
||||||
pub struct Retry {
|
|
||||||
pub error: EmbedError,
|
|
||||||
strategy: RetryStrategy,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub enum RetryStrategy {
|
|
||||||
GiveUp,
|
|
||||||
Retry,
|
|
||||||
RetryTokenized,
|
|
||||||
RetryAfterRateLimit,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Retry {
|
|
||||||
pub fn give_up(error: EmbedError) -> Self {
|
|
||||||
Self { error, strategy: RetryStrategy::GiveUp }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn retry_later(error: EmbedError) -> Self {
|
|
||||||
Self { error, strategy: RetryStrategy::Retry }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn retry_tokenized(error: EmbedError) -> Self {
|
|
||||||
Self { error, strategy: RetryStrategy::RetryTokenized }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rate_limited(error: EmbedError) -> Self {
|
|
||||||
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
|
|
||||||
match self.strategy {
|
|
||||||
RetryStrategy::GiveUp => Err(self.error),
|
|
||||||
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
|
|
||||||
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
|
|
||||||
RetryStrategy::RetryAfterRateLimit => {
|
|
||||||
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn must_tokenize(&self) -> bool {
|
|
||||||
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_error(self) -> EmbedError {
|
|
||||||
self.error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// openai api structs
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
|
|
||||||
model: &'a str,
|
|
||||||
input: &'a [S],
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
dimensions: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
struct OpenAiTokensRequest<'a> {
|
|
||||||
model: &'a str,
|
|
||||||
input: &'a [usize],
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
dimensions: Option<usize>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OpenAiResponse {
|
|
||||||
data: Vec<OpenAiEmbedding>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OpenAiErrorResponse {
|
|
||||||
error: OpenAiError,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct OpenAiError {
|
|
||||||
message: String,
|
|
||||||
// type: String,
|
|
||||||
code: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for OpenAiError {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match &self.code {
|
|
||||||
Some(code) => write!(f, "{} ({})", self.message, code),
|
|
||||||
None => write!(f, "{}", self.message),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct OpenAiEmbedding {
|
|
||||||
embedding: Embedding,
|
|
||||||
// object: String,
|
|
||||||
// index: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn infer_api_key() -> String {
|
fn infer_api_key() -> String {
|
||||||
std::env::var("MEILI_OPENAI_API_KEY")
|
std::env::var("MEILI_OPENAI_API_KEY")
|
||||||
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod sync {
|
#[derive(Debug)]
|
||||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
pub struct Embedder {
|
||||||
|
tokenizer: tiktoken_rs::CoreBPE,
|
||||||
|
rest_embedder: RestEmbedder,
|
||||||
|
options: EmbedderOptions,
|
||||||
|
}
|
||||||
|
|
||||||
use super::{
|
impl Embedder {
|
||||||
EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest,
|
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||||
OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL,
|
let mut inferred_api_key = Default::default();
|
||||||
};
|
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
||||||
use crate::vector::DistributionShift;
|
inferred_api_key = infer_api_key();
|
||||||
|
&inferred_api_key
|
||||||
|
});
|
||||||
|
|
||||||
const REQUEST_PARALLELISM: usize = 10;
|
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
|
||||||
|
api_key: Some(api_key.clone()),
|
||||||
|
distribution: options.embedding_model.distribution(),
|
||||||
|
dimensions: Some(options.dimensions()),
|
||||||
|
url: OPENAI_EMBEDDINGS_URL.to_owned(),
|
||||||
|
query: options.query(),
|
||||||
|
input_field: vec!["input".to_owned()],
|
||||||
|
input_type: crate::vector::rest::InputType::TextArray,
|
||||||
|
path_to_embeddings: vec!["data".to_owned()],
|
||||||
|
embedding_object: vec!["embedding".to_owned()],
|
||||||
|
})?;
|
||||||
|
|
||||||
#[derive(Debug)]
|
// looking at the code it is very unclear that this can actually fail.
|
||||||
pub struct Embedder {
|
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||||
tokenizer: tiktoken_rs::CoreBPE,
|
|
||||||
options: super::EmbedderOptions,
|
Ok(Self { options, rest_embedder, tokenizer })
|
||||||
bearer: String,
|
|
||||||
threads: rayon::ThreadPool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedder {
|
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
pub fn new(options: super::EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
match self.rest_embedder.embed_ref(&texts) {
|
||||||
let mut inferred_api_key = Default::default();
|
Ok(embeddings) => Ok(embeddings),
|
||||||
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => {
|
||||||
inferred_api_key = super::infer_api_key();
|
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
||||||
&inferred_api_key
|
self.try_embed_tokenized(&texts)
|
||||||
});
|
}
|
||||||
let bearer = format!("Bearer {api_key}");
|
Err(error) => Err(error),
|
||||||
|
|
||||||
// looking at the code it is very unclear that this can actually fail.
|
|
||||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
|
||||||
|
|
||||||
// FIXME: unwrap
|
|
||||||
let threads = rayon::ThreadPoolBuilder::new()
|
|
||||||
.num_threads(REQUEST_PARALLELISM)
|
|
||||||
.thread_name(|index| format!("embedder-chunk-{index}"))
|
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
Ok(Self { options, bearer, tokenizer, threads })
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
let mut tokenized = false;
|
pub const OVERLAP_SIZE: usize = 200;
|
||||||
|
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||||
let client = ureq::agent();
|
for text in text {
|
||||||
|
let max_token_count = self.options.embedding_model.max_token();
|
||||||
for attempt in 0..7 {
|
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||||
let result = if tokenized {
|
let len = encoded.len();
|
||||||
self.try_embed_tokenized(&texts, &client)
|
if len < max_token_count {
|
||||||
} else {
|
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
|
||||||
self.try_embed(&texts, &client)
|
continue;
|
||||||
};
|
|
||||||
|
|
||||||
let retry_duration = match result {
|
|
||||||
Ok(embeddings) => return Ok(embeddings),
|
|
||||||
Err(retry) => {
|
|
||||||
tracing::warn!("Failed: {}", retry.error);
|
|
||||||
tokenized |= retry.must_tokenize();
|
|
||||||
retry.into_duration(attempt)
|
|
||||||
}
|
|
||||||
}?;
|
|
||||||
|
|
||||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
|
||||||
tracing::warn!(
|
|
||||||
"Attempt #{}, retrying after {}ms.",
|
|
||||||
attempt,
|
|
||||||
retry_duration.as_millis()
|
|
||||||
);
|
|
||||||
std::thread::sleep(retry_duration);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = if tokenized {
|
let mut tokens = encoded.as_slice();
|
||||||
self.try_embed_tokenized(&texts, &client)
|
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
||||||
} else {
|
while tokens.len() > max_token_count {
|
||||||
self.try_embed(&texts, &client)
|
let window = &tokens[..max_token_count];
|
||||||
};
|
let embedding = self.rest_embedder.embed_tokens(window)?;
|
||||||
|
/// FIXME: unwrap
|
||||||
|
embeddings_for_prompt.append(embedding.into_inner()).unwrap();
|
||||||
|
|
||||||
result.map_err(Retry::into_error)
|
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
||||||
}
|
|
||||||
|
|
||||||
fn check_response(
|
|
||||||
response: Result<ureq::Response, ureq::Error>,
|
|
||||||
) -> Result<ureq::Response, Retry> {
|
|
||||||
match response {
|
|
||||||
Ok(response) => Ok(response),
|
|
||||||
Err(ureq::Error::Status(code, response)) => {
|
|
||||||
let error_response: Option<OpenAiErrorResponse> = response.into_json().ok();
|
|
||||||
let error = error_response.map(|response| response.error);
|
|
||||||
Err(match code {
|
|
||||||
401 => Retry::give_up(EmbedError::openai_auth_error(error)),
|
|
||||||
429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)),
|
|
||||||
400 => {
|
|
||||||
tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
|
||||||
|
|
||||||
Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error))
|
|
||||||
}
|
|
||||||
500..=599 => {
|
|
||||||
Retry::retry_later(EmbedError::openai_internal_server_error(error))
|
|
||||||
}
|
|
||||||
x => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(ureq::Error::Transport(transport)) => {
|
|
||||||
Err(Retry::retry_later(EmbedError::openai_network(transport)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn try_embed<S: AsRef<str> + serde::Serialize>(
|
|
||||||
&self,
|
|
||||||
texts: &[S],
|
|
||||||
client: &ureq::Agent,
|
|
||||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
|
||||||
for text in texts {
|
|
||||||
tracing::trace!("Received prompt: {}", text.as_ref())
|
|
||||||
}
|
|
||||||
let request = OpenAiRequest {
|
|
||||||
model: self.options.embedding_model.name(),
|
|
||||||
input: texts,
|
|
||||||
dimensions: self.overriden_dimensions(),
|
|
||||||
};
|
|
||||||
let response = client
|
|
||||||
.post(OPENAI_EMBEDDINGS_URL)
|
|
||||||
.set("Authorization", &self.bearer)
|
|
||||||
.send_json(&request);
|
|
||||||
|
|
||||||
let response = Self::check_response(response)?;
|
|
||||||
|
|
||||||
let response: OpenAiResponse = response
|
|
||||||
.into_json()
|
|
||||||
.map_err(EmbedError::openai_unexpected)
|
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
|
|
||||||
tracing::trace!("response: {:?}", response.data);
|
|
||||||
|
|
||||||
Ok(response
|
|
||||||
.data
|
|
||||||
.into_iter()
|
|
||||||
.map(|data| Embeddings::from_single_embedding(data.embedding))
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn try_embed_tokenized(
|
|
||||||
&self,
|
|
||||||
text: &[String],
|
|
||||||
client: &ureq::Agent,
|
|
||||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
|
||||||
pub const OVERLAP_SIZE: usize = 200;
|
|
||||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
|
||||||
for text in text {
|
|
||||||
let max_token_count = self.options.embedding_model.max_token();
|
|
||||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
|
||||||
let len = encoded.len();
|
|
||||||
if len < max_token_count {
|
|
||||||
all_embeddings.append(&mut self.try_embed(&[text], client)?);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut tokens = encoded.as_slice();
|
|
||||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
|
||||||
while tokens.len() > max_token_count {
|
|
||||||
let window = &tokens[..max_token_count];
|
|
||||||
embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap();
|
|
||||||
|
|
||||||
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
|
||||||
}
|
|
||||||
|
|
||||||
// end of text
|
|
||||||
embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap();
|
|
||||||
|
|
||||||
all_embeddings.push(embeddings_for_prompt);
|
|
||||||
}
|
|
||||||
Ok(all_embeddings)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> {
|
|
||||||
for attempt in 0..9 {
|
|
||||||
let duration = match self.try_embed_tokens(tokens, client) {
|
|
||||||
Ok(embedding) => return Ok(embedding),
|
|
||||||
Err(retry) => retry.into_duration(attempt),
|
|
||||||
}
|
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
|
|
||||||
std::thread::sleep(duration);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.try_embed_tokens(tokens, client)
|
// end of text
|
||||||
.map_err(|retry| Retry::give_up(retry.into_error()))
|
let embedding = self.rest_embedder.embed_tokens(tokens)?;
|
||||||
|
/// FIXME: unwrap
|
||||||
|
embeddings_for_prompt.append(embedding.into_inner()).unwrap();
|
||||||
|
|
||||||
|
all_embeddings.push(embeddings_for_prompt);
|
||||||
}
|
}
|
||||||
|
Ok(all_embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
fn try_embed_tokens(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
tokens: &[usize],
|
text_chunks: Vec<Vec<String>>,
|
||||||
client: &ureq::Agent,
|
threads: &rayon::ThreadPool,
|
||||||
) -> Result<Embedding, Retry> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
let request = OpenAiTokensRequest {
|
threads.install(move || {
|
||||||
model: self.options.embedding_model.name(),
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
input: tokens,
|
})
|
||||||
dimensions: self.overriden_dimensions(),
|
}
|
||||||
};
|
|
||||||
let response = client
|
|
||||||
.post(OPENAI_EMBEDDINGS_URL)
|
|
||||||
.set("Authorization", &self.bearer)
|
|
||||||
.send_json(&request);
|
|
||||||
|
|
||||||
let response = Self::check_response(response)?;
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
self.rest_embedder.chunk_count_hint()
|
||||||
|
}
|
||||||
|
|
||||||
let mut response: OpenAiResponse = response
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
.into_json()
|
self.rest_embedder.prompt_count_in_chunk_hint()
|
||||||
.map_err(EmbedError::openai_unexpected)
|
}
|
||||||
.map_err(Retry::retry_later)?;
|
|
||||||
|
|
||||||
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
pub fn dimensions(&self) -> usize {
|
||||||
}
|
self.options.dimensions()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn embed_chunks(
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
&self,
|
self.options.embedding_model.distribution()
|
||||||
text_chunks: Vec<Vec<String>>,
|
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
|
||||||
self.threads
|
|
||||||
.install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
|
||||||
10
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
|
||||||
10
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dimensions(&self) -> usize {
|
|
||||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
|
||||||
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
|
|
||||||
} else {
|
|
||||||
self.options.embedding_model.default_dimensions()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
|
||||||
self.options.embedding_model.distribution()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn overriden_dimensions(&self) -> Option<usize> {
|
|
||||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
|
||||||
self.options.dimensions
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,62 @@
|
|||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
use super::openai::Retry;
|
use super::{
|
||||||
use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError};
|
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||||
use crate::VectorOrArrayOfVectors;
|
};
|
||||||
|
|
||||||
|
// retrying in case of failure
|
||||||
|
|
||||||
|
pub struct Retry {
|
||||||
|
pub error: EmbedError,
|
||||||
|
strategy: RetryStrategy,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum RetryStrategy {
|
||||||
|
GiveUp,
|
||||||
|
Retry,
|
||||||
|
RetryTokenized,
|
||||||
|
RetryAfterRateLimit,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Retry {
|
||||||
|
pub fn give_up(error: EmbedError) -> Self {
|
||||||
|
Self { error, strategy: RetryStrategy::GiveUp }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn retry_later(error: EmbedError) -> Self {
|
||||||
|
Self { error, strategy: RetryStrategy::Retry }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn retry_tokenized(error: EmbedError) -> Self {
|
||||||
|
Self { error, strategy: RetryStrategy::RetryTokenized }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rate_limited(error: EmbedError) -> Self {
|
||||||
|
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
|
||||||
|
match self.strategy {
|
||||||
|
RetryStrategy::GiveUp => Err(self.error),
|
||||||
|
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
|
||||||
|
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
|
||||||
|
RetryStrategy::RetryAfterRateLimit => {
|
||||||
|
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn must_tokenize(&self) -> bool {
|
||||||
|
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_error(self) -> EmbedError {
|
||||||
|
self.error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
client: ureq::Agent,
|
client: ureq::Agent,
|
||||||
options: EmbedderOptions,
|
options: EmbedderOptions,
|
||||||
@ -11,20 +64,35 @@ pub struct Embedder {
|
|||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
distribution: Option<DistributionShift>,
|
pub distribution: Option<DistributionShift>,
|
||||||
dimensions: Option<usize>,
|
pub dimensions: Option<usize>,
|
||||||
url: String,
|
pub url: String,
|
||||||
query: liquid::Template,
|
pub query: serde_json::Value,
|
||||||
response_field: Vec<String>,
|
pub input_field: Vec<String>,
|
||||||
|
// path to the array of embeddings
|
||||||
|
pub path_to_embeddings: Vec<String>,
|
||||||
|
// shape of a single embedding
|
||||||
|
pub embedding_object: Vec<String>,
|
||||||
|
pub input_type: InputType,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum InputType {
|
||||||
|
Text,
|
||||||
|
TextArray,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||||
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer: {api_key}"));
|
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
|
||||||
|
|
||||||
let client = ureq::agent();
|
let client = ureq::AgentBuilder::new()
|
||||||
|
.max_idle_connections(REQUEST_PARALLELISM * 2)
|
||||||
|
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
|
||||||
|
.build();
|
||||||
|
|
||||||
let dimensions = if let Some(dimensions) = options.dimensions {
|
let dimensions = if let Some(dimensions) = options.dimensions {
|
||||||
dimensions
|
dimensions
|
||||||
@ -36,7 +104,20 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice())
|
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
|
where
|
||||||
|
S: AsRef<str> + Serialize,
|
||||||
|
{
|
||||||
|
embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
||||||
|
let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
|
||||||
|
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
||||||
|
Ok(embeddings.pop().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
@ -44,17 +125,20 @@ impl Embedder {
|
|||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &rayon::ThreadPool,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads
|
threads.install(move || {
|
||||||
.install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
.collect()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
10
|
super::REQUEST_PARALLELISM
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
10
|
match self.options.input_type {
|
||||||
|
InputType::Text => 1,
|
||||||
|
InputType::TextArray => 10,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dimensions(&self) -> usize {
|
pub fn dimensions(&self) -> usize {
|
||||||
@ -71,9 +155,9 @@ fn infer_dimensions(
|
|||||||
options: &EmbedderOptions,
|
options: &EmbedderOptions,
|
||||||
bearer: Option<&str>,
|
bearer: Option<&str>,
|
||||||
) -> Result<usize, NewEmbedderError> {
|
) -> Result<usize, NewEmbedderError> {
|
||||||
let v = embed(client, options, bearer, ["test"].as_slice())
|
let v = embed(client, options, bearer, ["test"].as_slice(), 1)
|
||||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
// unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(v.first().unwrap().dimension())
|
Ok(v.first().unwrap().dimension())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,33 +166,57 @@ fn embed<S>(
|
|||||||
options: &EmbedderOptions,
|
options: &EmbedderOptions,
|
||||||
bearer: Option<&str>,
|
bearer: Option<&str>,
|
||||||
inputs: &[S],
|
inputs: &[S],
|
||||||
|
expected_count: usize,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
where
|
where
|
||||||
S: serde::Serialize,
|
S: Serialize,
|
||||||
{
|
{
|
||||||
let request = client.post(&options.url);
|
let request = client.post(&options.url);
|
||||||
let request =
|
let request =
|
||||||
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
|
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
|
||||||
let request = request.set("Content-Type", "application/json");
|
let request = request.set("Content-Type", "application/json");
|
||||||
|
|
||||||
let body = options
|
let input_value = match options.input_type {
|
||||||
.query
|
InputType::Text => serde_json::json!(inputs.first()),
|
||||||
.render(
|
InputType::TextArray => serde_json::json!(inputs),
|
||||||
&liquid::to_object(&serde_json::json!({
|
};
|
||||||
"input": inputs,
|
|
||||||
}))
|
let body = match options.input_field.as_slice() {
|
||||||
.map_err(EmbedError::rest_template_context_serialization)?,
|
[] => {
|
||||||
)
|
// inject input in body
|
||||||
.map_err(EmbedError::rest_template_render)?;
|
input_value
|
||||||
|
}
|
||||||
|
[input] => {
|
||||||
|
let mut body = options.query.clone();
|
||||||
|
|
||||||
|
/// FIXME unwrap
|
||||||
|
body.as_object_mut().unwrap().insert(input.clone(), input_value);
|
||||||
|
body
|
||||||
|
}
|
||||||
|
[path @ .., input] => {
|
||||||
|
let mut body = options.query.clone();
|
||||||
|
|
||||||
|
/// FIXME unwrap
|
||||||
|
let mut current_value = &mut body;
|
||||||
|
for component in path {
|
||||||
|
current_value = current_value
|
||||||
|
.as_object_mut()
|
||||||
|
.unwrap()
|
||||||
|
.entry(component.clone())
|
||||||
|
.or_insert(serde_json::json!({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
|
||||||
|
body
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
for attempt in 0..7 {
|
for attempt in 0..7 {
|
||||||
let response = request.send_string(&body);
|
let response = request.clone().send_json(&body);
|
||||||
let result = check_response(response);
|
let result = check_response(response);
|
||||||
|
|
||||||
let retry_duration = match result {
|
let retry_duration = match result {
|
||||||
Ok(response) => {
|
Ok(response) => return response_to_embedding(response, options, expected_count),
|
||||||
return response_to_embedding(response, &options.response_field, inputs.len())
|
|
||||||
}
|
|
||||||
Err(retry) => {
|
Err(retry) => {
|
||||||
tracing::warn!("Failed: {}", retry.error);
|
tracing::warn!("Failed: {}", retry.error);
|
||||||
retry.into_duration(attempt)
|
retry.into_duration(attempt)
|
||||||
@ -120,11 +228,11 @@ where
|
|||||||
std::thread::sleep(retry_duration);
|
std::thread::sleep(retry_duration);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = request.send_string(&body);
|
let response = request.send_json(&body);
|
||||||
let result = check_response(response);
|
let result = check_response(response);
|
||||||
result
|
result
|
||||||
.map_err(Retry::into_error)
|
.map_err(Retry::into_error)
|
||||||
.and_then(|response| response_to_embedding(response, &options.response_field, inputs.len()))
|
.and_then(|response| response_to_embedding(response, options, expected_count))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
|
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
|
||||||
@ -139,7 +247,10 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq:
|
|||||||
500..=599 => {
|
500..=599 => {
|
||||||
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
||||||
}
|
}
|
||||||
x => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
|
402..=499 => {
|
||||||
|
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
|
||||||
|
}
|
||||||
|
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Err(ureq::Error::Transport(transport)) => {
|
Err(ureq::Error::Transport(transport)) => {
|
||||||
@ -148,34 +259,66 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn response_to_embedding<S: AsRef<str>>(
|
fn response_to_embedding(
|
||||||
response: ureq::Response,
|
response: ureq::Response,
|
||||||
response_field: &[S],
|
options: &EmbedderOptions,
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
let response: serde_json::Value =
|
let response: serde_json::Value =
|
||||||
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
||||||
|
|
||||||
let mut current_value = &response;
|
let mut current_value = &response;
|
||||||
for component in response_field {
|
for component in &options.path_to_embeddings {
|
||||||
let component = component.as_ref();
|
let component = component.as_ref();
|
||||||
let current_value = current_value.get(component).ok_or_else(|| {
|
current_value = current_value.get(component).ok_or_else(|| {
|
||||||
EmbedError::rest_response_missing_embeddings(response, component, response_field)
|
EmbedError::rest_response_missing_embeddings(
|
||||||
|
response.clone(),
|
||||||
|
component,
|
||||||
|
&options.path_to_embeddings,
|
||||||
|
)
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let embeddings = current_value.to_owned();
|
let embeddings = match options.input_type {
|
||||||
|
InputType::Text => {
|
||||||
|
for component in &options.embedding_object {
|
||||||
|
current_value = current_value.get(component).ok_or_else(|| {
|
||||||
|
EmbedError::rest_response_missing_embeddings(
|
||||||
|
response.clone(),
|
||||||
|
component,
|
||||||
|
&options.embedding_object,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
let embeddings = current_value.to_owned();
|
||||||
|
let embeddings: Embedding =
|
||||||
|
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
|
||||||
|
|
||||||
let embeddings: VectorOrArrayOfVectors =
|
vec![Embeddings::from_single_embedding(embeddings)]
|
||||||
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
|
}
|
||||||
|
InputType::TextArray => {
|
||||||
let embeddings = embeddings.into_array_of_vectors();
|
let empty = vec![];
|
||||||
|
let values = current_value.as_array().unwrap_or(&empty);
|
||||||
let embeddings: Vec<Embeddings<f32>> = embeddings
|
let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
|
||||||
.into_iter()
|
for value in values {
|
||||||
.flatten()
|
let mut current_value = value;
|
||||||
.map(|embedding| Embeddings::from_single_embedding(embedding))
|
for component in &options.embedding_object {
|
||||||
.collect();
|
current_value = current_value.get(component).ok_or_else(|| {
|
||||||
|
EmbedError::rest_response_missing_embeddings(
|
||||||
|
response.clone(),
|
||||||
|
component,
|
||||||
|
&options.embedding_object,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
let embedding = current_value.to_owned();
|
||||||
|
let embedding: Embedding =
|
||||||
|
serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
|
||||||
|
embeddings.push(Embeddings::from_single_embedding(embedding));
|
||||||
|
}
|
||||||
|
embeddings
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if embeddings.len() != expected_count {
|
if embeddings.len() != expected_count {
|
||||||
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
||||||
|
@ -204,7 +204,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
},
|
},
|
||||||
super::EmbedderOptions::Ollama(options) => Self {
|
super::EmbedderOptions::Ollama(options) => Self {
|
||||||
source: Setting::Set(EmbedderSource::Ollama),
|
source: Setting::Set(EmbedderSource::Ollama),
|
||||||
model: Setting::Set(options.embedding_model.name().to_owned()),
|
model: Setting::Set(options.embedding_model.to_owned()),
|
||||||
revision: Setting::NotSet,
|
revision: Setting::NotSet,
|
||||||
api_key: Setting::NotSet,
|
api_key: Setting::NotSet,
|
||||||
dimensions: Setting::NotSet,
|
dimensions: Setting::NotSet,
|
||||||
@ -248,7 +248,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
|||||||
let mut options: ollama::EmbedderOptions =
|
let mut options: ollama::EmbedderOptions =
|
||||||
super::ollama::EmbedderOptions::with_default_model();
|
super::ollama::EmbedderOptions::with_default_model();
|
||||||
if let Some(model) = model.set() {
|
if let Some(model) = model.set() {
|
||||||
options.embedding_model = super::ollama::EmbeddingModel::from_name(&model);
|
options.embedding_model = model;
|
||||||
}
|
}
|
||||||
this.embedder_options = super::EmbedderOptions::Ollama(options);
|
this.embedder_options = super::EmbedderOptions::Ollama(options);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user