use self::error::{EmbedError, NewEmbedderError}; use crate::prompt::PromptData; pub mod error; pub mod hf; pub mod openai; pub mod settings; pub use self::error::Error; pub type Embedding = Vec; pub struct Embeddings { data: Vec, dimension: usize, } impl Embeddings { pub fn new(dimension: usize) -> Self { Self { data: Default::default(), dimension } } pub fn from_single_embedding(embedding: Vec) -> Self { Self { dimension: embedding.len(), data: embedding } } pub fn from_inner(data: Vec, dimension: usize) -> Result> { let mut this = Self::new(dimension); this.append(data)?; Ok(this) } pub fn dimension(&self) -> usize { self.dimension } pub fn into_inner(self) -> Vec { self.data } pub fn as_inner(&self) -> &[F] { &self.data } pub fn iter(&self) -> impl Iterator + '_ { self.data.as_slice().chunks_exact(self.dimension) } pub fn push(&mut self, mut embedding: Vec) -> Result<(), Vec> { if embedding.len() != self.dimension { return Err(embedding); } self.data.append(&mut embedding); Ok(()) } pub fn append(&mut self, mut embeddings: Vec) -> Result<(), Vec> { if embeddings.len() % self.dimension != 0 { return Err(embeddings); } self.data.append(&mut embeddings); Ok(()) } } #[derive(Debug)] pub enum Embedder { HuggingFace(hf::Embedder), OpenAi(openai::Embedder), } #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] pub struct EmbeddingConfig { pub embedder_options: EmbedderOptions, pub prompt: PromptData, // TODO: add metrics and anything needed } #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub enum EmbedderOptions { HuggingFace(hf::EmbedderOptions), OpenAi(openai::EmbedderOptions), } impl Default for EmbedderOptions { fn default() -> Self { Self::HuggingFace(Default::default()) } } impl EmbedderOptions { pub fn huggingface() -> Self { Self::HuggingFace(hf::EmbedderOptions::new()) } pub fn openai(api_key: String) -> Self { Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) } } impl Embedder { pub fn new(options: EmbedderOptions) -> std::result::Result { Ok(match options { EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), }) } pub async fn embed( &self, texts: Vec, ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts).await, Embedder::OpenAi(embedder) => embedder.embed(texts).await, } } pub async fn embed_chunks( &self, text_chunks: Vec>, ) -> std::result::Result>>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, } } pub fn chunk_count_hint(&self) -> usize { match self { Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), } } pub fn prompt_count_in_chunk_hint(&self) -> usize { match self { Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), } } } #[derive(Debug, Clone, Copy)] pub struct DistributionShift { pub current_mean: f32, pub current_sigma: f32, } impl DistributionShift { /// `None` if sigma <= 0. pub fn new(mean: f32, sigma: f32) -> Option { if sigma <= 0.0 { None } else { Some(Self { current_mean: mean, current_sigma: sigma }) } } pub fn shift(&self, score: f32) -> f32 { // // We're somewhat abusively mapping the distribution of distances to a gaussian. // The parameters we're given is the mean and sigma of the native result distribution. // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4. let target_mean = 0.5; let target_sigma = 0.4; // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. let factor = target_sigma / self.current_sigma; // a*mu1 + b = mu2 => b = mu2 - a*mu1 let offset = target_mean - (factor * self.current_mean); let mut score = factor * score + offset; // clamp the final score in the ]0, 1] interval. if score <= 0.0 { score = f32::EPSILON; } if score > 1.0 { score = 1.0; } score } }