Expose REST embedder to the API

This commit is contained in:
Louis Dureuil 2024-03-25 10:05:38 +01:00
parent f87747f4d3
commit a1db342f01
No known key found for this signature in database
7 changed files with 357 additions and 32 deletions

View file

@ -1,6 +1,9 @@
use std::collections::HashMap;
use std::sync::Arc;
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
@ -104,7 +107,10 @@ pub enum Embedder {
OpenAi(openai::Embedder),
/// An embedder based on the user providing the embeddings in the documents and queries.
UserProvided(manual::Embedder),
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
Ollama(ollama::Embedder),
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
Rest(rest::Embedder),
}
/// Configuration for an embedder.
@ -175,6 +181,7 @@ pub enum EmbedderOptions {
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
Rest(rest::EmbedderOptions),
}
impl Default for EmbedderOptions {
@ -209,6 +216,7 @@ impl Embedder {
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
})
}
@ -224,6 +232,7 @@ impl Embedder {
Embedder::OpenAi(embedder) => embedder.embed(texts),
Embedder::Ollama(embedder) => embedder.embed(texts),
Embedder::UserProvided(embedder) => embedder.embed(texts),
Embedder::Rest(embedder) => embedder.embed(texts),
}
}
@ -240,6 +249,7 @@ impl Embedder {
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
}
}
@ -250,6 +260,7 @@ impl Embedder {
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
}
}
@ -260,6 +271,7 @@ impl Embedder {
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
}
}
@ -270,6 +282,7 @@ impl Embedder {
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
Embedder::Rest(embedder) => embedder.dimensions(),
}
}
@ -280,6 +293,7 @@ impl Embedder {
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
Embedder::Rest(embedder) => embedder.distribution(),
}
}
}
@ -288,17 +302,47 @@ impl Embedder {
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
/// This allows to correct effects where results are too "packed" around a certain value.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(from = "DistributionShiftSerializable")]
#[serde(into = "DistributionShiftSerializable")]
pub struct DistributionShift {
/// Value where the results are "packed".
///
/// Similarity scores are translated so that they are packed around 0.5 instead
pub current_mean: f32,
pub current_mean: OrderedFloat<f32>,
/// standard deviation of a similarity score.
///
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
pub current_sigma: f32,
pub current_sigma: OrderedFloat<f32>,
}
#[derive(Serialize, Deserialize)]
struct DistributionShiftSerializable {
current_mean: f32,
current_sigma: f32,
}
impl From<DistributionShift> for DistributionShiftSerializable {
fn from(
DistributionShift {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
Self { current_mean, current_sigma }
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
fn from(
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
) -> Self {
Self {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}
}
}
impl DistributionShift {
@ -307,11 +351,13 @@ impl DistributionShift {
if sigma <= 0.0 {
None
} else {
Some(Self { current_mean: mean, current_sigma: sigma })
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
}
}
pub fn shift(&self, score: f32) -> f32 {
let current_mean = self.current_mean.0;
let current_sigma = self.current_sigma.0;
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution.
@ -321,9 +367,9 @@ impl DistributionShift {
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
let factor = target_sigma / self.current_sigma;
let factor = target_sigma / current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
let offset = target_mean - (factor * self.current_mean);
let offset = target_mean - (factor * current_mean);
let mut score = factor * score + offset;