Add maxBytes parameter

This commit is contained in:
Louis Dureuil 2024-08-27 17:52:09 +02:00
parent c49d892c82
commit 1ac008926b
No known key found for this signature in database
3 changed files with 62 additions and 5 deletions

View File

@ -2740,6 +2740,7 @@ mod tests {
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::Set(3), dimensions: Setting::Set(3),
document_template: Setting::NotSet, document_template: Setting::NotSet,
document_template_max_bytes: Setting::NotSet,
url: Setting::NotSet, url: Setting::NotSet,
request: Setting::NotSet, request: Setting::NotSet,
response: Setting::NotSet, response: Setting::NotSet,

View File

@ -1,5 +1,6 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::convert::TryInto; use std::convert::TryInto;
use std::num::NonZeroUsize;
use std::result::Result as StdResult; use std::result::Result as StdResult;
use std::sync::Arc; use std::sync::Arc;
@ -19,6 +20,7 @@ use crate::index::{
IndexEmbeddingConfig, DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS, IndexEmbeddingConfig, DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS,
}; };
use crate::order_by_map::OrderByMap; use crate::order_by_map::OrderByMap;
use crate::prompt::default_max_bytes;
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod; use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep}; use crate::update::{IndexDocuments, UpdateIndexingStep};
@ -1573,16 +1575,30 @@ fn validate_prompt(
api_key, api_key,
dimensions, dimensions,
document_template: Setting::Set(template), document_template: Setting::Set(template),
document_template_max_bytes,
url, url,
request, request,
response, response,
distribution, distribution,
headers, headers,
}) => { }) => {
let max_bytes = match document_template_max_bytes.set() {
Some(max_bytes) => NonZeroUsize::new(max_bytes).ok_or_else(|| {
crate::error::UserError::InvalidSettingsDocumentTemplateMaxBytes {
embedder_name: name.to_owned(),
}
})?,
None => default_max_bytes(),
};
// validate // validate
let template = crate::prompt::Prompt::new(template) let template = crate::prompt::Prompt::new(
.map(|prompt| crate::prompt::PromptData::from(prompt).template) template,
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; // always specify a max_bytes
Some(max_bytes),
)
.map(|prompt| crate::prompt::PromptData::from(prompt).template)
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;
Ok(Setting::Set(EmbeddingSettings { Ok(Setting::Set(EmbeddingSettings {
source, source,
@ -1591,6 +1607,7 @@ fn validate_prompt(
api_key, api_key,
dimensions, dimensions,
document_template: Setting::Set(template), document_template: Setting::Set(template),
document_template_max_bytes,
url, url,
request, request,
response, response,
@ -1615,6 +1632,7 @@ pub fn validate_embedding_settings(
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
document_template_max_bytes,
url, url,
request, request,
response, response,
@ -1654,6 +1672,7 @@ pub fn validate_embedding_settings(
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
document_template_max_bytes,
url, url,
request, request,
response, response,
@ -1726,6 +1745,12 @@ pub fn validate_embedding_settings(
inferred_source, inferred_source,
name, name,
)?; )?;
check_unset(
&document_template_max_bytes,
EmbeddingSettings::DOCUMENT_TEMPLATE_MAX_BYTES,
inferred_source,
name,
)?;
check_set(&dimensions, EmbeddingSettings::DIMENSIONS, inferred_source, name)?; check_set(&dimensions, EmbeddingSettings::DIMENSIONS, inferred_source, name)?;
check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?; check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?;
@ -1748,6 +1773,7 @@ pub fn validate_embedding_settings(
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
document_template_max_bytes,
url, url,
request, request,
response, response,

View File

@ -1,11 +1,12 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::num::NonZeroUsize;
use deserr::Deserr; use deserr::Deserr;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{ollama, openai, DistributionShift}; use super::{ollama, openai, DistributionShift};
use crate::prompt::PromptData; use crate::prompt::{default_max_bytes, PromptData};
use crate::update::Setting; use crate::update::Setting;
use crate::vector::EmbeddingConfig; use crate::vector::EmbeddingConfig;
use crate::UserError; use crate::UserError;
@ -34,6 +35,9 @@ pub struct EmbeddingSettings {
pub document_template: Setting<String>, pub document_template: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
pub document_template_max_bytes: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub url: Setting<String>, pub url: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
@ -111,6 +115,7 @@ impl SettingsDiff {
mut response, mut response,
mut distribution, mut distribution,
mut headers, mut headers,
mut document_template_max_bytes,
} = old; } = old;
let EmbeddingSettings { let EmbeddingSettings {
@ -125,6 +130,7 @@ impl SettingsDiff {
response: new_response, response: new_response,
distribution: new_distribution, distribution: new_distribution,
headers: new_headers, headers: new_headers,
document_template_max_bytes: new_document_template_max_bytes,
} = new; } = new;
let mut reindex_action = None; let mut reindex_action = None;
@ -142,6 +148,7 @@ impl SettingsDiff {
&mut request, &mut request,
&mut response, &mut response,
&mut document_template, &mut document_template,
&mut document_template_max_bytes,
&mut headers, &mut headers,
) )
} }
@ -189,6 +196,12 @@ impl SettingsDiff {
ReindexAction::RegeneratePrompts, ReindexAction::RegeneratePrompts,
); );
} }
if document_template_max_bytes.apply(new_document_template_max_bytes) {
ReindexAction::push_action(
&mut reindex_action,
ReindexAction::RegeneratePrompts,
)
}
distribution.apply(new_distribution); distribution.apply(new_distribution);
api_key.apply(new_api_key); api_key.apply(new_api_key);
@ -206,6 +219,7 @@ impl SettingsDiff {
response, response,
distribution, distribution,
headers, headers,
document_template_max_bytes,
}; };
match reindex_action { match reindex_action {
@ -239,6 +253,7 @@ fn apply_default_for_source(
request: &mut Setting<serde_json::Value>, request: &mut Setting<serde_json::Value>,
response: &mut Setting<serde_json::Value>, response: &mut Setting<serde_json::Value>,
document_template: &mut Setting<String>, document_template: &mut Setting<String>,
document_template_max_bytes: &mut Setting<usize>,
headers: &mut Setting<BTreeMap<String, String>>, headers: &mut Setting<BTreeMap<String, String>>,
) { ) {
match source { match source {
@ -286,6 +301,7 @@ fn apply_default_for_source(
*request = Setting::NotSet; *request = Setting::NotSet;
*response = Setting::NotSet; *response = Setting::NotSet;
*document_template = Setting::NotSet; *document_template = Setting::NotSet;
*document_template_max_bytes = Setting::NotSet;
*headers = Setting::NotSet; *headers = Setting::NotSet;
} }
Setting::NotSet => {} Setting::NotSet => {}
@ -316,6 +332,7 @@ impl EmbeddingSettings {
pub const API_KEY: &'static str = "apiKey"; pub const API_KEY: &'static str = "apiKey";
pub const DIMENSIONS: &'static str = "dimensions"; pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
pub const DOCUMENT_TEMPLATE_MAX_BYTES: &'static str = "documentTemplateMaxBytes";
pub const URL: &'static str = "url"; pub const URL: &'static str = "url";
pub const REQUEST: &'static str = "request"; pub const REQUEST: &'static str = "request";
@ -459,6 +476,8 @@ impl std::fmt::Display for EmbedderSource {
impl From<EmbeddingConfig> for EmbeddingSettings { impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self { fn from(value: EmbeddingConfig) -> Self {
let EmbeddingConfig { embedder_options, prompt } = value; let EmbeddingConfig { embedder_options, prompt } = value;
let document_template_max_bytes =
Setting::Set(prompt.max_bytes.unwrap_or(default_max_bytes()).get());
match embedder_options { match embedder_options {
super::EmbedderOptions::HuggingFace(super::hf::EmbedderOptions { super::EmbedderOptions::HuggingFace(super::hf::EmbedderOptions {
model, model,
@ -471,6 +490,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::NotSet, dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::NotSet, url: Setting::NotSet,
request: Setting::NotSet, request: Setting::NotSet,
response: Setting::NotSet, response: Setting::NotSet,
@ -490,6 +510,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::some_or_not_set(api_key), api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions), dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::some_or_not_set(url), url: Setting::some_or_not_set(url),
request: Setting::NotSet, request: Setting::NotSet,
response: Setting::NotSet, response: Setting::NotSet,
@ -509,6 +530,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::some_or_not_set(api_key), api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions), dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::some_or_not_set(url), url: Setting::some_or_not_set(url),
request: Setting::NotSet, request: Setting::NotSet,
response: Setting::NotSet, response: Setting::NotSet,
@ -525,6 +547,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::NotSet, api_key: Setting::NotSet,
dimensions: Setting::Set(dimensions), dimensions: Setting::Set(dimensions),
document_template: Setting::NotSet, document_template: Setting::NotSet,
document_template_max_bytes: Setting::NotSet,
url: Setting::NotSet, url: Setting::NotSet,
request: Setting::NotSet, request: Setting::NotSet,
response: Setting::NotSet, response: Setting::NotSet,
@ -546,6 +569,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: Setting::some_or_not_set(api_key), api_key: Setting::some_or_not_set(api_key),
dimensions: Setting::some_or_not_set(dimensions), dimensions: Setting::some_or_not_set(dimensions),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
document_template_max_bytes,
url: Setting::Set(url), url: Setting::Set(url),
request: Setting::Set(request), request: Setting::Set(request),
response: Setting::Set(response), response: Setting::Set(response),
@ -566,6 +590,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
api_key, api_key,
dimensions, dimensions,
document_template, document_template,
document_template_max_bytes,
url, url,
request, request,
response, response,
@ -648,7 +673,12 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
} }
if let Setting::Set(template) = document_template { if let Setting::Set(template) = document_template {
this.prompt = PromptData { template } let max_bytes = document_template_max_bytes
.set()
.and_then(NonZeroUsize::new)
.unwrap_or(default_max_bytes());
this.prompt = PromptData { template, max_bytes: Some(max_bytes) }
} }
this this