Changes to prompt

This commit is contained in:
Louis Dureuil 2024-09-03 12:07:10 +02:00
parent de962a26f3
commit c49d892c82
No known key found for this signature in database

View File

@ -6,6 +6,7 @@ mod template_checker;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::num::NonZeroUsize;
use std::ops::Deref; use std::ops::Deref;
use error::{NewPromptError, RenderPromptError}; use error::{NewPromptError, RenderPromptError};
@ -18,16 +19,18 @@ use crate::{FieldId, FieldsIdsMap};
pub struct Prompt { pub struct Prompt {
template: liquid::Template, template: liquid::Template,
template_text: String, template_text: String,
max_bytes: Option<NonZeroUsize>,
} }
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PromptData { pub struct PromptData {
pub template: String, pub template: String,
pub max_bytes: Option<NonZeroUsize>,
} }
impl From<Prompt> for PromptData { impl From<Prompt> for PromptData {
fn from(value: Prompt) -> Self { fn from(value: Prompt) -> Self {
Self { template: value.template_text } Self { template: value.template_text, max_bytes: value.max_bytes }
} }
} }
@ -35,14 +38,18 @@ impl TryFrom<PromptData> for Prompt {
type Error = NewPromptError; type Error = NewPromptError;
fn try_from(value: PromptData) -> Result<Self, Self::Error> { fn try_from(value: PromptData) -> Result<Self, Self::Error> {
Prompt::new(value.template) Prompt::new(value.template, value.max_bytes)
} }
} }
impl Clone for Prompt { impl Clone for Prompt {
fn clone(&self) -> Self { fn clone(&self) -> Self {
let template_text = self.template_text.clone(); let template_text = self.template_text.clone();
Self { template: new_template(&template_text).unwrap(), template_text } Self {
template: new_template(&template_text).unwrap(),
template_text,
max_bytes: self.max_bytes,
}
} }
} }
@ -62,20 +69,28 @@ fn default_template_text() -> &'static str {
{% endfor %}" {% endfor %}"
} }
pub fn default_max_bytes() -> NonZeroUsize {
NonZeroUsize::new(400).unwrap()
}
impl Default for Prompt { impl Default for Prompt {
fn default() -> Self { fn default() -> Self {
Self { template: default_template(), template_text: default_template_text().into() } Self {
template: default_template(),
template_text: default_template_text().into(),
max_bytes: Some(default_max_bytes()),
}
} }
} }
impl Default for PromptData { impl Default for PromptData {
fn default() -> Self { fn default() -> Self {
Self { template: default_template_text().into() } Self { template: default_template_text().into(), max_bytes: Some(default_max_bytes()) }
} }
} }
impl Prompt { impl Prompt {
pub fn new(template: String) -> Result<Self, NewPromptError> { pub fn new(template: String, max_bytes: Option<NonZeroUsize>) -> Result<Self, NewPromptError> {
let this = Self { let this = Self {
template: liquid::ParserBuilder::with_stdlib() template: liquid::ParserBuilder::with_stdlib()
.build() .build()
@ -83,6 +98,7 @@ impl Prompt {
.parse(&template) .parse(&template)
.map_err(NewPromptError::cannot_parse_template)?, .map_err(NewPromptError::cannot_parse_template)?,
template_text: template, template_text: template,
max_bytes,
}; };
// render template with special object that's OK with `doc.*` and `fields.*` // render template with special object that's OK with `doc.*` and `fields.*`
@ -102,7 +118,24 @@ impl Prompt {
let document = Document::new(document, side, field_id_map); let document = Document::new(document, side, field_id_map);
let context = Context::new(&document, field_id_map); let context = Context::new(&document, field_id_map);
self.template.render(&context).map_err(RenderPromptError::missing_context) let mut rendered =
self.template.render(&context).map_err(RenderPromptError::missing_context)?;
if let Some(max_bytes) = self.max_bytes {
truncate(&mut rendered, max_bytes.get());
}
Ok(rendered)
}
}
fn truncate(s: &mut String, max_bytes: usize) {
if max_bytes >= s.len() {
return;
}
for i in (0..=max_bytes).rev() {
if s.is_char_boundary(i) {
s.truncate(i);
break;
}
} }
} }
@ -145,6 +178,7 @@ mod test {
use super::Prompt; use super::Prompt;
use crate::error::FaultSource; use crate::error::FaultSource;
use crate::prompt::error::{NewPromptError, NewPromptErrorKind}; use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
use crate::prompt::truncate;
#[test] #[test]
fn default_template() { fn default_template() {
@ -154,18 +188,18 @@ mod test {
#[test] #[test]
fn empty_template() { fn empty_template() {
Prompt::new("".into()).unwrap(); Prompt::new("".into(), None).unwrap();
} }
#[test] #[test]
fn template_ok() { fn template_ok() {
Prompt::new("{{doc.title}}: {{doc.overview}}".into()).unwrap(); Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None).unwrap();
} }
#[test] #[test]
fn template_syntax() { fn template_syntax() {
assert!(matches!( assert!(matches!(
Prompt::new("{{doc.title: {{doc.overview}}".into()), Prompt::new("{{doc.title: {{doc.overview}}".into(), None),
Err(NewPromptError { Err(NewPromptError {
kind: NewPromptErrorKind::CannotParseTemplate(_), kind: NewPromptErrorKind::CannotParseTemplate(_),
fault: FaultSource::User fault: FaultSource::User
@ -176,7 +210,7 @@ mod test {
#[test] #[test]
fn template_missing_doc() { fn template_missing_doc() {
assert!(matches!( assert!(matches!(
Prompt::new("{{title}}: {{overview}}".into()), Prompt::new("{{title}}: {{overview}}".into(), None),
Err(NewPromptError { Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User fault: FaultSource::User
@ -186,17 +220,20 @@ mod test {
#[test] #[test]
fn template_nested_doc() { fn template_nested_doc() {
Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into()).unwrap(); Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None).unwrap();
} }
#[test] #[test]
fn template_fields() { fn template_fields() {
Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into()).unwrap(); Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None).unwrap();
} }
#[test] #[test]
fn template_fields_ok() { fn template_fields_ok() {
Prompt::new("{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into()) Prompt::new(
"{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(),
None,
)
.unwrap(); .unwrap();
} }
@ -204,11 +241,41 @@ mod test {
fn template_fields_invalid() { fn template_fields_invalid() {
assert!(matches!( assert!(matches!(
// intentionally garbled field // intentionally garbled field
Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()), Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None),
Err(NewPromptError { Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User fault: FaultSource::User
}) })
)); ));
} }
// todo: test truncation
#[test]
fn template_truncation() {
let mut s = "インテル ザー ビーグル".to_string();
truncate(&mut s, 42);
assert_eq!(s, "インテル ザー ビーグル");
assert_eq!(s.len(), 32);
truncate(&mut s, 32);
assert_eq!(s, "インテル ザー ビーグル");
truncate(&mut s, 31);
assert_eq!(s, "インテル ザー ビーグ");
truncate(&mut s, 30);
assert_eq!(s, "インテル ザー ビーグ");
truncate(&mut s, 28);
assert_eq!(s, "インテル ザー ビー");
truncate(&mut s, 26);
assert_eq!(s, "インテル ザー ビー");
truncate(&mut s, 25);
assert_eq!(s, "インテル ザー ビ");
assert_eq!("".len(), 3);
truncate(&mut s, 3);
assert_eq!(s, "");
truncate(&mut s, 2);
assert_eq!(s, "");
}
} }