Various changes

- DistributionShift in Search object (to be set from model in embed?)
- Fix issue where embedder index wasn't computed at search time
- Accept as default embedder either the "default" one, or the only embedder when there is only one
This commit is contained in:
Louis Dureuil 2023-12-13 15:38:44 +01:00
parent 12940d79a9
commit e0cc775dc4
No known key found for this signature in database
12 changed files with 141 additions and 33 deletions

View file

@ -1,5 +1,8 @@
use std::collections::HashMap;
use std::sync::Arc;
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::PromptData;
use crate::prompt::{Prompt, PromptData};
pub mod error;
pub mod hf;
@ -82,6 +85,44 @@ pub struct EmbeddingConfig {
// TODO: add metrics and anything needed
}
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
impl EmbeddingConfigs {
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
Self(data)
}
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.0.get(name).cloned()
}
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
self.get_default_embedder_name().and_then(|default| self.get(&default))
}
pub fn get_default_embedder_name(&self) -> Option<String> {
let mut it = self.0.keys();
let first_name = it.next();
let second_name = it.next();
match (first_name, second_name) {
(None, _) => None,
(Some(first), None) => Some(first.to_owned()),
(Some(_), Some(_)) => Some("default".to_owned()),
}
}
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>));
type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),