Implemented Ollama as an embeddings provider

Initial prototype of Ollama embeddings actually working, error handlign / retries still missing.

Allow model to be any String and require dimensions parameter

Fixed rustfmt formatting issues

There were some formatting issues in the initial PR and this should not make the changes comply with the Rust style guidelines

Because I accidentally didn't follow the style guide for commits in my commit messages I squashed them into one to comply
This commit is contained in:
Jakob Klemm 2024-03-03 01:11:25 +01:00
parent 938149f814
commit d3004d8040
No known key found for this signature in database
GPG Key ID: D91BAB52F26F2A75
7 changed files with 350 additions and 15 deletions

View File

@ -604,6 +604,7 @@ fn embedder_analytics(
EmbedderSource::OpenAi => sources.insert("openAi"),
EmbedderSource::HuggingFace => sources.insert("huggingFace"),
EmbedderSource::UserProvided => sources.insert("userProvided"),
EmbedderSource::Ollama => sources.insert("ollama"),
};
}
};

View File

@ -1178,6 +1178,13 @@ pub fn validate_embedding_settings(
}
}
}
EmbedderSource::Ollama => {
// Existence & corrent dimensions of models cannot easily be checked here.
check_set(&dimensions, "dimensions", inferred_source, name)?;
check_set(&model, "model", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
}
EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;

View File

@ -2,6 +2,7 @@ use std::path::PathBuf;
use hf_hub::api::sync::ApiError;
use super::ollama::OllamaError;
use crate::error::FaultSource;
use crate::vector::openai::OpenAiError;
@ -71,6 +72,15 @@ pub enum EmbedErrorKind {
OpenAiRuntimeInit(std::io::Error),
#[error("initializing web client for sending embedding requests failed: {0}")]
InitWebClient(reqwest::Error),
// Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends.
#[error("unexpected response from Ollama: {0}")]
OllamaUnexpected(reqwest::Error),
#[error("sent too many requests to Ollama: {0}")]
OllamaTooManyRequests(OllamaError),
#[error("received internal error from Ollama: {0}")]
OllamaInternalServerError(OllamaError),
#[error("received unhandled HTTP status code {0} from Ollama")]
OllamaUnhandledStatusCode(u16),
}
impl EmbedError {
@ -129,6 +139,22 @@ impl EmbedError {
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
pub fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
}
pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime }
}
pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime }
}
pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
}
}
#[derive(Debug, thiserror::Error)]

View File

@ -10,6 +10,8 @@ pub mod manual;
pub mod openai;
pub mod settings;
pub mod ollama;
pub use self::error::Error;
pub type Embedding = Vec<f32>;
@ -76,6 +78,7 @@ pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
Ollama(ollama::Embedder),
}
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
@ -127,6 +130,7 @@ impl IntoIterator for EmbeddingConfigs {
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}
@ -144,6 +148,10 @@ impl EmbedderOptions {
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
pub fn ollama() -> Self {
Self::Ollama(ollama::EmbedderOptions::with_default_model())
}
}
impl Embedder {
@ -151,6 +159,7 @@ impl Embedder {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
@ -167,6 +176,10 @@ impl Embedder {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::Ollama(embedder) => {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
@ -181,6 +194,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
@ -189,6 +203,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
@ -197,6 +212,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
@ -205,6 +221,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
@ -213,6 +230,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
}
}

255
milli/src/vector/ollama.rs Normal file
View File

@ -0,0 +1,255 @@
// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
// The common components of the Ollama and OpenAI interfaces might need to be extracted.
use std::fmt::Display;
use reqwest::StatusCode;
use super::error::{EmbedError, NewEmbedderError};
use super::openai::Retry;
use super::{DistributionShift, Embedding, Embeddings};
#[derive(Debug)]
pub struct Embedder {
headers: reqwest::header::HeaderMap,
options: EmbedderOptions,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub embedding_model: EmbeddingModel,
pub dimensions: usize,
}
#[derive(
Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
)]
#[deserr(deny_unknown_fields)]
pub struct EmbeddingModel {
name: String,
}
#[derive(Debug, serde::Serialize)]
struct OllamaRequest<'a> {
model: &'a str,
prompt: &'a str,
}
#[derive(Debug, serde::Deserialize)]
struct OllamaResponse {
embedding: Embedding,
}
#[derive(Debug, serde::Deserialize)]
struct OllamaErrorResponse {
error: OllamaError,
}
#[derive(Debug, serde::Deserialize)]
pub struct OllamaError {
message: String,
// type: String,
code: Option<String>,
}
impl EmbeddingModel {
pub fn max_token(&self) -> usize {
// this might not be the same for all models
8192
}
pub fn default_dimensions(&self) -> usize {
// Dimensions for nomic-embed-text
768
}
pub fn name(&self) -> String {
self.name.clone()
}
pub fn from_name(name: &str) -> Self {
Self { name: name.to_string() }
}
pub fn supports_overriding_dimensions(&self) -> bool {
false
}
}
impl Default for EmbeddingModel {
fn default() -> Self {
Self { name: "nomic-embed-text".to_string() }
}
}
impl EmbedderOptions {
pub fn with_default_model() -> Self {
Self { embedding_model: Default::default(), dimensions: 768 }
}
pub fn with_embedding_model(embedding_model: EmbeddingModel, dimensions: usize) -> Self {
Self { embedding_model, dimensions }
}
}
impl Embedder {
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
reqwest::ClientBuilder::new()
.default_headers(self.headers.clone())
.build()
.map_err(EmbedError::openai_initialize_web_client)
}
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_static("application/json"),
);
Ok(Self { options, headers })
}
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
if !response.status().is_success() {
// Not the same number of possible error cases covered as with OpenAI.
match response.status() {
StatusCode::TOO_MANY_REQUESTS => {
let error_response: OllamaErrorResponse = response
.json()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
error_response.error,
)));
}
StatusCode::SERVICE_UNAVAILABLE => {
let error_response: OllamaErrorResponse = response
.json()
.await
.map_err(EmbedError::ollama_unexpected)
.map_err(Retry::retry_later)?;
return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
error_response.error,
)));
}
code => {
return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
code.as_u16(),
)));
}
}
}
Ok(response)
}
pub async fn embed(
&self,
texts: Vec<String>,
client: &reqwest::Client,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
// Ollama only embedds one document at a time.
let mut results = Vec::with_capacity(texts.len());
// The retry loop is inside the texts loop, might have to switch that around
for text in texts {
// Retries copied from openai.rs
for attempt in 0..7 {
let retry_duration = match self.try_embed(&text, client).await {
Ok(result) => {
results.push(result);
break;
}
Err(retry) => {
tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt)
}
}?;
tracing::warn!(
"Attempt #{}, retrying after {}ms.",
attempt,
retry_duration.as_millis()
);
tokio::time::sleep(retry_duration).await;
}
}
Ok(results)
}
async fn try_embed(
&self,
text: &str,
client: &reqwest::Client,
) -> Result<Embeddings<f32>, Retry> {
let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
let response = client
.post(get_ollama_path())
.json(&request)
.send()
.await
.map_err(EmbedError::openai_network)
.map_err(Retry::retry_later)?;
let response = Self::check_response(response).await?;
let response: OllamaResponse = response
.json()
.await
.map_err(EmbedError::openai_unexpected)
.map_err(Retry::retry_later)?;
tracing::trace!("response: {:?}", response.embedding);
let embedding = Embeddings::from_single_embedding(response.embedding);
Ok(embedding)
}
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
.enable_time()
.build()
.map_err(EmbedError::openai_runtime_init)?;
let client = self.new_client()?;
rt.block_on(futures::future::try_join_all(
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
))
}
// Defaults copied from openai.rs
pub fn chunk_count_hint(&self) -> usize {
10
}
pub fn prompt_count_in_chunk_hint(&self) -> usize {
10
}
pub fn dimensions(&self) -> usize {
self.options.dimensions
}
pub fn distribution(&self) -> Option<DistributionShift> {
None
}
}
impl Display for OllamaError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.code {
Some(code) => write!(f, "{} ({})", self.message, code),
None => write!(f, "{}", self.message),
}
}
}
fn get_ollama_path() -> String {
// Important: Hostname not enough, has to be entire path to embeddings endpoint
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
}

View File

@ -429,12 +429,12 @@ impl Embedder {
// retrying in case of failure
struct Retry {
error: EmbedError,
pub struct Retry {
pub error: EmbedError,
strategy: RetryStrategy,
}
enum RetryStrategy {
pub enum RetryStrategy {
GiveUp,
Retry,
RetryTokenized,
@ -442,23 +442,23 @@ enum RetryStrategy {
}
impl Retry {
fn give_up(error: EmbedError) -> Self {
pub fn give_up(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::GiveUp }
}
fn retry_later(error: EmbedError) -> Self {
pub fn retry_later(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::Retry }
}
fn retry_tokenized(error: EmbedError) -> Self {
pub fn retry_tokenized(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryTokenized }
}
fn rate_limited(error: EmbedError) -> Self {
pub fn rate_limited(error: EmbedError) -> Self {
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
}
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
match self.strategy {
RetryStrategy::GiveUp => Err(self.error),
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
@ -469,11 +469,11 @@ impl Retry {
}
}
fn must_tokenize(&self) -> bool {
pub fn must_tokenize(&self) -> bool {
matches!(self.strategy, RetryStrategy::RetryTokenized)
}
fn into_error(self) -> EmbedError {
pub fn into_error(self) -> EmbedError {
self.error
}
}

View File

@ -1,7 +1,7 @@
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use super::openai;
use super::{ollama, openai};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
@ -80,11 +80,17 @@ impl EmbeddingSettings {
Self::SOURCE => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
}
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::MODEL => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
}
Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi],
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::DIMENSIONS => {
&[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Ollama]
}
Self::DOCUMENT_TEMPLATE => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
}
_other => unreachable!("unknown field"),
}
}
@ -101,6 +107,9 @@ impl EmbeddingSettings {
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::Ollama => {
&[Self::SOURCE, Self::MODEL, Self::DIMENSIONS, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
}
}
@ -134,6 +143,7 @@ pub enum EmbedderSource {
#[default]
OpenAi,
HuggingFace,
Ollama,
UserProvided,
}
@ -143,6 +153,7 @@ impl std::fmt::Display for EmbedderSource {
EmbedderSource::OpenAi => "openAi",
EmbedderSource::HuggingFace => "huggingFace",
EmbedderSource::UserProvided => "userProvided",
EmbedderSource::Ollama => "ollama",
};
f.write_str(s)
}
@ -192,7 +203,15 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
model: Setting::Set(options.embedding_model.name().to_owned()),
revision: Setting::NotSet,
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
dimensions: Setting::Set(options.dimensions.unwrap_or_default()),
document_template: Setting::Set(prompt.template),
},
super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama),
model: Setting::Set(options.embedding_model.name().to_owned()),
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(options.dimensions),
document_template: Setting::Set(prompt.template),
},
super::EmbedderOptions::UserProvided(options) => Self {
@ -229,6 +248,15 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
}
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::Ollama => {
let mut options: ollama::EmbedderOptions =
super::ollama::EmbedderOptions::with_default_model();
if let (Some(model), Some(dim)) = (model.set(), dimensions.set()) {
options.embedding_model = super::ollama::EmbeddingModel::from_name(&model);
options.dimensions = dim;
}
this.embedder_options = super::EmbedderOptions::Ollama(options);
}
EmbedderSource::HuggingFace => {
let mut options = super::hf::EmbedderOptions::default();
if let Some(model) = model.set() {