mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-24 13:54:26 +01:00
187 lines
5.2 KiB
Rust
187 lines
5.2 KiB
Rust
use self::error::{EmbedError, NewEmbedderError};
|
|
use crate::prompt::PromptData;
|
|
|
|
pub mod error;
|
|
pub mod hf;
|
|
pub mod openai;
|
|
pub mod settings;
|
|
|
|
pub use self::error::Error;
|
|
|
|
pub type Embedding = Vec<f32>;
|
|
|
|
/// A collection of same-sized embeddings stored back to back in one flat vector.
pub struct Embeddings<F> {
    /// Flat storage; embedding `i` occupies `data[i * dimension..(i + 1) * dimension]`.
    data: Vec<F>,
    /// Number of items in a single embedding.
    dimension: usize,
}

impl<F> Embeddings<F> {
    /// Creates an empty collection for embeddings of the given `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self { data: Vec::new(), dimension }
    }

    /// Creates a collection holding exactly one embedding.
    ///
    /// The dimension is taken from the length of `embedding`.
    pub fn from_single_embedding(embedding: Vec<F>) -> Self {
        Self { dimension: embedding.len(), data: embedding }
    }

    /// Creates a collection from a flat vector of embeddings.
    ///
    /// # Errors
    ///
    /// Gives `data` back if it cannot be split into whole embeddings of size `dimension`
    /// (see [`Embeddings::append`]).
    pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
        let mut this = Self::new(dimension);
        this.append(data)?;
        Ok(this)
    }

    /// Returns the number of items in a single embedding.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Consumes the collection and returns the flat underlying vector.
    pub fn into_inner(self) -> Vec<F> {
        self.data
    }

    /// Borrows the flat underlying storage.
    pub fn as_inner(&self) -> &[F] {
        &self.data
    }

    /// Iterates over the stored embeddings, each yielded as a slice of `dimension` items.
    pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
        // `chunks_exact(0)` panics. With a zero dimension nothing can be stored through
        // this API, so chunking the (empty) data by any non-zero size yields no item.
        self.data.chunks_exact(self.dimension.max(1))
    }

    /// Adds one embedding at the end of the collection.
    ///
    /// # Errors
    ///
    /// Gives `embedding` back, leaving the collection untouched, if its length differs
    /// from the collection's dimension.
    pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
        if embedding.len() != self.dimension {
            return Err(embedding);
        }
        self.data.append(&mut embedding);
        Ok(())
    }

    /// Adds a flat vector of zero or more embeddings at the end of the collection.
    ///
    /// # Errors
    ///
    /// Gives `embeddings` back, leaving the collection untouched, if its length is not a
    /// multiple of the collection's dimension. With a zero dimension only an empty vector
    /// is accepted.
    pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
        let is_whole = match self.dimension {
            // `len % 0` would panic: treat a zero dimension as "nothing fits".
            0 => embeddings.is_empty(),
            dimension => embeddings.len() % dimension == 0,
        };
        if !is_whole {
            return Err(embeddings);
        }
        self.data.append(&mut embeddings);
        Ok(())
    }
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum Embedder {
|
|
HuggingFace(hf::Embedder),
|
|
OpenAi(openai::Embedder),
|
|
}
|
|
|
|
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
|
pub struct EmbeddingConfig {
|
|
pub embedder_options: EmbedderOptions,
|
|
pub prompt: PromptData,
|
|
// TODO: add metrics and anything needed
|
|
}
|
|
|
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
|
pub enum EmbedderOptions {
|
|
HuggingFace(hf::EmbedderOptions),
|
|
OpenAi(openai::EmbedderOptions),
|
|
}
|
|
|
|
impl Default for EmbedderOptions {
|
|
fn default() -> Self {
|
|
Self::HuggingFace(Default::default())
|
|
}
|
|
}
|
|
|
|
impl EmbedderOptions {
|
|
pub fn huggingface() -> Self {
|
|
Self::HuggingFace(hf::EmbedderOptions::new())
|
|
}
|
|
|
|
pub fn openai(api_key: String) -> Self {
|
|
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
|
|
}
|
|
}
|
|
|
|
impl Embedder {
|
|
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
|
Ok(match options {
|
|
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
|
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
|
})
|
|
}
|
|
|
|
pub async fn embed(
|
|
&self,
|
|
texts: Vec<String>,
|
|
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
|
match self {
|
|
Embedder::HuggingFace(embedder) => embedder.embed(texts).await,
|
|
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
|
|
}
|
|
}
|
|
|
|
pub async fn embed_chunks(
|
|
&self,
|
|
text_chunks: Vec<Vec<String>>,
|
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
|
match self {
|
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await,
|
|
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
|
|
}
|
|
}
|
|
|
|
pub fn chunk_count_hint(&self) -> usize {
|
|
match self {
|
|
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
|
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
|
}
|
|
}
|
|
|
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
|
match self {
|
|
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
|
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Parameters of the native score distribution, used to remap scores with
/// [`DistributionShift::shift`].
#[derive(Debug, Clone, Copy)]
pub struct DistributionShift {
    /// Mean of the native score distribution.
    pub current_mean: f32,
    /// Sigma of the native score distribution.
    pub current_sigma: f32,
}

impl DistributionShift {
    /// `None` if sigma <= 0 or sigma is NaN.
    pub fn new(mean: f32, sigma: f32) -> Option<Self> {
        // Written as `sigma > 0.0` rather than `!(sigma <= 0.0)` so that a NaN sigma, for
        // which every comparison is false, is rejected instead of poisoning `shift`.
        if sigma > 0.0 {
            Some(Self { current_mean: mean, current_sigma: sigma })
        } else {
            None
        }
    }

    /// Remaps `score` from the native distribution to the target one, then restricts the
    /// result to the ]0, 1] interval.
    pub fn shift(&self, score: f32) -> f32 {
        // <https://math.stackexchange.com/a/2894689>
        // We're somewhat abusively mapping the distribution of distances to a gaussian.
        // The parameters we're given are the mean and sigma of the native result distribution.
        // We're using them to retarget the distribution to a gaussian centered on 0.5 with a
        // sigma of 0.4.
        const TARGET_MEAN: f32 = 0.5;
        const TARGET_SIGMA: f32 = 0.4;

        // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1,
        // assuming a, sig1, and sig2 positive.
        let factor = TARGET_SIGMA / self.current_sigma;
        // a*mu1 + b = mu2 => b = mu2 - a*mu1
        let offset = TARGET_MEAN - (factor * self.current_mean);

        let shifted = factor * score + offset;

        // Restrict the final score to the ]0, 1] interval: non-positive values become the
        // smallest step above zero we use (`f32::EPSILON`), values above 1 become 1.
        if shifted <= 0.0 {
            f32::EPSILON
        } else if shifted > 1.0 {
            1.0
        } else {
            shifted
        }
    }
}
|