WIP multi embedders

fixed template bugs
This commit is contained in:
Louis Dureuil 2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
No known key found for this signature in database
20 changed files with 438 additions and 158 deletions

View file

@ -62,6 +62,7 @@ pub struct Embedder {
model: BertModel,
tokenizer: Tokenizer,
options: EmbedderOptions,
dimensions: usize,
}
impl std::fmt::Debug for Embedder {
@ -126,10 +127,17 @@ impl Embedder {
tokenizer.with_padding(Some(pp));
}
Ok(Self { model, tokenizer, options })
let mut this = Self { model, tokenizer, options, dimensions: 0 };
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
}
pub async fn embed(
pub fn embed(
&self,
mut texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
@ -170,12 +178,11 @@ impl Embedder {
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
}
pub async fn embed_chunks(
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
.await
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn chunk_count_hint(&self) -> usize {
@ -185,6 +192,10 @@ impl Embedder {
pub fn prompt_count_in_chunk_hint(&self) -> usize {
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
}
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {