// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
// The common components of the Ollama and OpenAI interfaces might need to be extracted.

use std::fmt::Display;

use reqwest::StatusCode;

use super::error::{EmbedError, NewEmbedderError};
use super::openai::Retry;
use super::{DistributionShift, Embedding, Embeddings};

#[derive(Debug)]
pub struct Embedder {
    headers: reqwest::header::HeaderMap,
    options: EmbedderOptions,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub embedding_model: EmbeddingModel,
}

#[derive(
    Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
)]
#[deserr(deny_unknown_fields)]
pub struct EmbeddingModel {
    name: String,
    dimensions: usize,
}

#[derive(Debug, serde::Serialize)]
struct OllamaRequest<'a> {
    model: &'a str,
    prompt: &'a str,
}

#[derive(Debug, serde::Deserialize)]
struct OllamaResponse {
    embedding: Embedding,
}

#[derive(Debug, serde::Deserialize)]
pub struct OllamaError {
    error: String,
}
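// For reference, the JSON bodies implied by the serde derives above look like
// this (the embedding values and error text are illustrative only):
//
//   request:  {"model": "nomic-embed-text", "prompt": "some document text"}
//   response: {"embedding": [0.5670, 0.0092, ...]}
//   error:    {"error": "..."}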
impl EmbeddingModel {
    pub fn max_token(&self) -> usize {
        // this might not be the same for all models
        8192
    }

    pub fn default_dimensions(&self) -> usize {
        // Dimensions for nomic-embed-text
        768
    }

    pub fn name(&self) -> String {
        self.name.clone()
    }

    pub fn from_name(name: &str) -> Self {
        Self { name: name.to_string(), dimensions: 0 }
    }

    pub fn supports_overriding_dimensions(&self) -> bool {
        false
    }
}

impl Default for EmbeddingModel {
    fn default() -> Self {
        Self { name: "nomic-embed-text".to_string(), dimensions: 0 }
    }
}

impl EmbedderOptions {
    pub fn with_default_model() -> Self {
        Self { embedding_model: Default::default() }
    }

    pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self {
        Self { embedding_model }
    }
}

impl Embedder {
    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
        reqwest::ClientBuilder::new()
            .default_headers(self.headers.clone())
            .build()
            .map_err(EmbedError::openai_initialize_web_client)
    }

    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let mut embedder = Self { options, headers };

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .map_err(EmbedError::openai_runtime_init)
            .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

        // Get dimensions from Ollama by embedding a dummy prompt once.
        let request =
            OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" };
        // TODO: Refactor into shared error type
        let client = embedder
            .new_client()
            .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

        rt.block_on(async move {
            let response = client
                .post(get_ollama_path())
                .json(&request)
                .send()
                .await
                .map_err(EmbedError::ollama_unexpected)
                .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

            // Process the error in case the model was not found.
            let response = Self::check_response(response).await.map_err(|_err| {
                let e = EmbedError::ollama_model_not_found(OllamaError {
                    error: format!("model: {}", embedder.options.embedding_model.name()),
                });
                NewEmbedderError::ollama_could_not_determine_dimension(e)
            })?;

            let response: OllamaResponse = response
                .json()
                .await
                .map_err(EmbedError::ollama_unexpected)
                .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

            let embedding = Embeddings::from_single_embedding(response.embedding);

            embedder.options.embedding_model.dimensions = embedding.dimension();

            tracing::info!(
                "ollama model {} with dimensionality {} added",
                embedder.options.embedding_model.name(),
                embedding.dimension()
            );

            Ok(embedder)
        })
    }

    async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
        if !response.status().is_success() {
            // Not all of the error cases covered for OpenAI are handled here.
            match response.status() {
                StatusCode::TOO_MANY_REQUESTS => {
                    let error_response: OllamaError = response
                        .json()
                        .await
                        .map_err(EmbedError::ollama_unexpected)
                        .map_err(Retry::retry_later)?;

                    return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
                        OllamaError { error: error_response.error },
                    )));
                }
                StatusCode::SERVICE_UNAVAILABLE => {
                    let error_response: OllamaError = response
                        .json()
                        .await
                        .map_err(EmbedError::ollama_unexpected)
                        .map_err(Retry::retry_later)?;

                    return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
                        OllamaError { error: error_response.error },
                    )));
                }
                StatusCode::NOT_FOUND => {
                    let error_response: OllamaError = response
                        .json()
                        .await
                        .map_err(EmbedError::ollama_unexpected)
                        .map_err(Retry::give_up)?;

                    return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
                        error: error_response.error,
                    })));
                }
                code => {
                    return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
                        code.as_u16(),
                    )));
                }
            }
        }
        Ok(response)
    }

    pub async fn embed(
        &self,
        texts: Vec<String>,
        client: &reqwest::Client,
    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
        // Ollama only embeds one document at a time.
        let mut results = Vec::with_capacity(texts.len());

        // The retry loop is inside the texts loop; it might have to be the other way around.
        for text in texts {
            // Retries copied from openai.rs
            for attempt in 0..7 {
                let retry_duration = match self.try_embed(&text, client).await {
                    Ok(result) => {
                        results.push(result);
                        break;
                    }
                    Err(retry) => {
                        tracing::warn!("Failed: {}", retry.error);
                        retry.into_duration(attempt)
                    }
                }?;

                tracing::warn!(
                    "Attempt #{}, retrying after {}ms.",
                    attempt,
                    retry_duration.as_millis()
                );
                tokio::time::sleep(retry_duration).await;
            }
        }

        Ok(results)
    }

    async fn try_embed(
        &self,
        text: &str,
        client: &reqwest::Client,
    ) -> Result<Embeddings<f32>, Retry> {
        let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
        let response = client
            .post(get_ollama_path())
            .json(&request)
            .send()
            .await
            .map_err(EmbedError::openai_network)
            .map_err(Retry::retry_later)?;

        let response = Self::check_response(response).await?;

        let response: OllamaResponse = response
            .json()
            .await
            .map_err(EmbedError::openai_unexpected)
            .map_err(Retry::retry_later)?;

        tracing::trace!("response: {:?}", response.embedding);

        let embedding = Embeddings::from_single_embedding(response.embedding);
        Ok(embedding)
    }

    pub fn embed_chunks(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .map_err(EmbedError::openai_runtime_init)?;
        let client = self.new_client()?;
        rt.block_on(futures::future::try_join_all(
            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
        ))
    }

    // Defaults copied from openai.rs
    pub fn chunk_count_hint(&self) -> usize {
        10
    }

    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        10
    }

    pub fn dimensions(&self) -> usize {
        self.options.embedding_model.dimensions
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        None
    }
}

impl Display for OllamaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.error)
    }
}
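// The endpoint defaults to a local Ollama instance but can be overridden through
// the environment; the host in this example is hypothetical:
//
//   export MEILI_OLLAMA_URL=http://ollama.internal:11434/api/embeddings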
fn get_ollama_path() -> String {
    // Important: the hostname alone is not enough; this has to be the full path to the
    // embeddings endpoint.
    std::env::var("MEILI_OLLAMA_URL")
        .unwrap_or("http://localhost:11434/api/embeddings".to_string())
}
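// A minimal usage sketch, not part of the original module: it assumes a running
// Ollama instance that already has the default `nomic-embed-text` model pulled,
// which is why the test is `#[ignore]`d by default.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a running Ollama instance with nomic-embed-text pulled"]
    fn embed_reports_consistent_dimensions() {
        // `new` probes the server with a dummy prompt to discover the model's dimensionality.
        let embedder = Embedder::new(EmbedderOptions::with_default_model())
            .expect("Ollama should be reachable");
        let chunks = embedder
            .embed_chunks(vec![vec!["a first document".to_string()]])
            .expect("embedding a single prompt should succeed");
        // Every returned embedding matches the dimensionality probed in `new`.
        assert_eq!(chunks[0][0].dimension(), embedder.dimensions());
    }
}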