2023-12-12 23:39:01 +01:00
|
|
|
use super::error::EmbedError;
|
2024-03-27 11:50:22 +01:00
|
|
|
use super::{DistributionShift, Embeddings};
|
2023-12-12 23:39:01 +01:00
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy)]
|
|
|
|
pub struct Embedder {
|
|
|
|
dimensions: usize,
|
2024-03-27 11:50:22 +01:00
|
|
|
distribution: Option<DistributionShift>,
|
2023-12-12 23:39:01 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
|
|
|
pub struct EmbedderOptions {
|
|
|
|
pub dimensions: usize,
|
2024-03-27 11:50:22 +01:00
|
|
|
pub distribution: Option<DistributionShift>,
|
2023-12-12 23:39:01 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Embedder {
|
|
|
|
pub fn new(options: EmbedderOptions) -> Self {
|
2024-03-27 11:50:22 +01:00
|
|
|
Self { dimensions: options.dimensions, distribution: options.distribution }
|
2023-12-12 23:39:01 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
|
|
|
let Some(text) = texts.pop() else { return Ok(Default::default()) };
|
|
|
|
Err(EmbedError::embed_on_manual_embedder(text))
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn dimensions(&self) -> usize {
|
|
|
|
self.dimensions
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn embed_chunks(
|
|
|
|
&self,
|
|
|
|
text_chunks: Vec<Vec<String>>,
|
|
|
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
|
|
|
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
|
|
|
}
|
2024-03-27 11:50:22 +01:00
|
|
|
|
|
|
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
|
|
|
self.distribution
|
|
|
|
}
|
2023-12-12 23:39:01 +01:00
|
|
|
}
|