add prompt and context

This commit is contained in:
Louis Dureuil 2024-04-10 09:43:33 +02:00
parent f505fa4ae8
commit 9cef8ec087
No known key found for this signature in database
7 changed files with 279 additions and 24 deletions

View File

@ -245,7 +245,9 @@ InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
InvalidRecommendContext , InvalidRequest , BAD_REQUEST ;
InvalidRecommendId , InvalidRequest , BAD_REQUEST ;
InvalidRecommendPrompt , InvalidRequest , BAD_REQUEST ;
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
InvalidSearchHighlightPreTag , InvalidRequest , BAD_REQUEST ;
@ -309,6 +311,8 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
MissingIndexUid , InvalidRequest , BAD_REQUEST ;
MissingMasterKey , Auth , UNAUTHORIZED ;
MissingPayload , InvalidRequest , BAD_REQUEST ;
MissingPrompt , InvalidRequest , BAD_REQUEST ;
MissingPromptOrId , InvalidRequest , BAD_REQUEST ;
MissingSearchHybrid , InvalidRequest , BAD_REQUEST ;
MissingSwapIndexes , InvalidRequest , BAD_REQUEST ;
MissingTaskFilters , InvalidRequest , BAD_REQUEST ;

View File

@ -61,6 +61,10 @@ pub enum MeilisearchHttpError {
Join(#[from] JoinError),
#[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
MissingSearchHybrid,
#[error("Invalid request: `prompt` parameter is required when `context` is present.")]
RecommendMissingPrompt,
#[error("Invalid request: one of the `prompt` or `id` parameters is required.")]
RecommendMissingPromptOrId,
}
impl ErrorCode for MeilisearchHttpError {
@ -89,6 +93,8 @@ impl ErrorCode for MeilisearchHttpError {
MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
MeilisearchHttpError::Join(_) => Code::Internal,
MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
MeilisearchHttpError::RecommendMissingPrompt => Code::MissingPrompt,
MeilisearchHttpError::RecommendMissingPromptOrId => Code::MissingPromptOrId,
}
}
}

View File

@ -316,7 +316,7 @@ impl SearchQueryWithIndex {
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
pub struct RecommendQuery {
#[deserr(default, error = DeserrJsonError<InvalidRecommendId>)]
pub id: String,
pub id: Option<String>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@ -331,6 +331,11 @@ pub struct RecommendQuery {
pub show_ranking_score: bool,
#[deserr(default, error = DeserrJsonError<InvalidSearchShowRankingScoreDetails>, default)]
pub show_ranking_score_details: bool,
#[deserr(default, error = DeserrJsonError<InvalidRecommendPrompt>)]
pub prompt: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidRecommendContext>)]
pub context: Option<Value>,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)]
@ -418,7 +423,8 @@ pub struct SearchResult {
#[serde(rename_all = "camelCase")]
pub struct RecommendResult {
pub hits: Vec<SearchHit>,
pub id: String,
pub id: Option<String>,
pub prompt: Option<String>,
pub processing_time_ms: u128,
#[serde(flatten)]
pub hits_info: HitsInfo,
@ -836,20 +842,41 @@ pub fn perform_recommend(
let before_search = Instant::now();
let rtxn = index.read_txn()?;
let internal_id = index
.external_documents_ids()
.get(&rtxn, &query.id)?
.ok_or_else(|| MeilisearchHttpError::DocumentNotFound(query.id.clone()))?;
let internal_id = query
.id
.as_deref()
.map(|id| -> Result<_, MeilisearchHttpError> {
Ok(index
.external_documents_ids()
.get(&rtxn, id)?
.ok_or_else(|| MeilisearchHttpError::DocumentNotFound(id.to_owned()))?)
})
.transpose()?;
let mut recommend = milli::Recommend::new(
internal_id,
query.offset,
query.limit,
index,
&rtxn,
embedder_name,
embedder,
);
let mut recommend = match (query.prompt.as_deref(), internal_id, query.context) {
(None, Some(internal_id), None) => milli::Recommend::with_docid(
internal_id,
query.offset,
query.limit,
index,
&rtxn,
embedder_name,
embedder,
),
(Some(prompt), internal_id, context) => milli::Recommend::with_prompt(
prompt,
internal_id,
context,
query.offset,
query.limit,
index,
&rtxn,
embedder_name,
embedder,
),
(None, _, Some(_)) => return Err(MeilisearchHttpError::RecommendMissingPrompt.into()),
(None, None, None) => return Err(MeilisearchHttpError::RecommendMissingPromptOrId.into()),
};
if let Some(ref filter) = query.filter {
if let Some(facets) = parse_filter(filter)? {
@ -947,6 +974,7 @@ pub fn perform_recommend(
hits: documents,
hits_info,
id: query.id,
prompt: query.prompt,
processing_time_ms: before_search.elapsed().as_millis(),
};
Ok(result)

View File

@ -29,7 +29,7 @@ impl ParsedValue {
}
impl<'a> Document<'a> {
pub fn new(
pub fn from_deladd_obkv(
data: obkv::KvReaderU16<'a>,
side: DelAdd,
inverted_field_map: &'a FieldsIdsMap,
@ -48,6 +48,20 @@ impl<'a> Document<'a> {
Self(out_data)
}
pub fn from_doc_obkv(
data: obkv::KvReaderU16<'a>,
inverted_field_map: &'a FieldsIdsMap,
) -> Self {
let mut out_data = BTreeMap::new();
for (fid, raw) in data {
let Some(name) = inverted_field_map.name(fid) else {
continue;
};
out_data.insert(name, (raw, ParsedValue::empty()));
}
Self(out_data)
}
fn is_empty(&self) -> bool {
self.0.is_empty()
}

View File

@ -2,6 +2,7 @@ mod context;
mod document;
pub(crate) mod error;
mod fields;
pub mod recommend;
mod template_checker;
use std::convert::TryFrom;
@ -9,7 +10,7 @@ use std::convert::TryFrom;
use error::{NewPromptError, RenderPromptError};
use self::context::Context;
use self::document::Document;
pub use self::document::Document;
use crate::update::del_add::DelAdd;
use crate::FieldsIdsMap;
@ -95,7 +96,7 @@ impl Prompt {
side: DelAdd,
field_id_map: &FieldsIdsMap,
) -> Result<String, RenderPromptError> {
let document = Document::new(document, side, field_id_map);
let document = Document::from_deladd_obkv(document, side, field_id_map);
let context = Context::new(&document, field_id_map);
self.template.render(&context).map_err(RenderPromptError::missing_context)

View File

@ -0,0 +1,112 @@
use liquid::model::{
DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
};
use liquid::{ObjectView, ValueView};
use super::document::Document;
/// Liquid object view exposing at most two keys: `doc` and `context`.
///
/// A key is only reachable through `get` when the matching field is `Some`.
#[derive(Clone, Debug)]
pub struct Context<'a> {
    /// Borrowed document, returned for the `doc` key when present.
    document: Option<&'a Document<'a>>,
    /// Liquid object, returned for the `context` key when present.
    context: Option<liquid::Object>,
}
impl<'a> Context<'a> {
pub fn new(document: Option<&'a Document<'a>>, context: Option<serde_json::Value>) -> Self {
/// FIXME: unwrap
let context = context.map(|context| liquid::to_object(&context).unwrap());
Self { document, context }
}
}
/// Exposes the optional `context` and `doc` entries as a liquid object.
///
/// `keys` and `values` both yield `context` first and `doc` second, so `iter`
/// can zip them together.
impl<'a> ObjectView for Context<'a> {
    fn as_value(&self) -> &dyn ValueView {
        self
    }

    /// Number of entries that are currently present (0, 1 or 2).
    fn size(&self) -> i64 {
        i64::from(self.context.is_some()) + i64::from(self.document.is_some())
    }

    /// Names of the present entries, `context` before `doc`.
    fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
        let context_key = self.context.as_ref().map(|_| "context");
        let doc_key = self.document.map(|_| "doc");
        Box::new(context_key.into_iter().chain(doc_key).map(KStringCow::from_static))
    }

    /// Values of the present entries, in the same order as `keys`.
    fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
        Box::new(
            self.context
                .as_ref()
                .map(|context| context.as_value())
                .into_iter()
                .chain(self.document.map(|document| document.as_value())),
        )
    }

    fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
        Box::new(self.keys().zip(self.values()))
    }

    /// Returns `true` only for an entry that `get` would actually return.
    ///
    /// Delegating to `get` keeps this consistent with `keys` and `size` when
    /// the document or the context is absent.
    fn contains_key(&self, index: &str) -> bool {
        self.get(index).is_some()
    }

    /// Looks up `context` or `doc`; any other key, or an absent entry, is `None`.
    fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
        match index {
            "context" => self.context.as_ref().map(|context| context.as_value()),
            "doc" => self.document.map(|doc| doc.as_value()),
            _ => None,
        }
    }
}
impl<'a> ValueView for Context<'a> {
fn as_debug(&self) -> &dyn std::fmt::Debug {
self
}
fn render(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
}
fn source(&self) -> liquid::model::DisplayCow<'_> {
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
}
fn type_name(&self) -> &'static str {
"object"
}
fn query_state(&self, state: liquid::model::State) -> bool {
match state {
State::Truthy => true,
State::DefaultValue | State::Empty | State::Blank => false,
}
}
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
let s = ObjectRender::new(self).to_string();
KStringCow::from_string(s)
}
fn to_value(&self) -> LiquidValue {
LiquidValue::Object(
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
)
}
fn as_object(&self) -> Option<&dyn ObjectView> {
Some(self)
}
}

View File

@ -1,13 +1,20 @@
use std::sync::Arc;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use serde_json::Value;
use crate::score_details::{self, ScoreDetails};
use crate::vector::Embedder;
use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult};
/// Input a `Recommend` request was built from.
enum RecommendKind<'a> {
    /// A document id only; set by `Recommend::with_docid`.
    Id(DocumentId),
    /// A prompt string with an optional context value and an optional document id;
    /// set by `Recommend::with_prompt`.
    Prompt { prompt: &'a str, context: Option<Value>, id: Option<DocumentId> },
}
pub struct Recommend<'a> {
id: DocumentId,
kind: RecommendKind<'a>,
// this should be linked to the String in the query
filter: Option<Filter<'a>>,
offset: usize,
@ -19,7 +26,7 @@ pub struct Recommend<'a> {
}
impl<'a> Recommend<'a> {
pub fn new(
pub fn with_docid(
id: DocumentId,
offset: usize,
limit: usize,
@ -28,7 +35,39 @@ impl<'a> Recommend<'a> {
embedder_name: String,
embedder: Arc<Embedder>,
) -> Self {
Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder }
Self {
kind: RecommendKind::Id(id),
filter: None,
offset,
limit,
rtxn,
index,
embedder_name,
embedder,
}
}
/// Creates a recommendation request of the `RecommendKind::Prompt` kind.
///
/// `prompt`, `context` and `id` are stored together in the request kind; the
/// remaining arguments are stored as-is. The request starts without a filter.
pub fn with_prompt(
    prompt: &'a str,
    id: Option<DocumentId>,
    context: Option<Value>,
    offset: usize,
    limit: usize,
    index: &'a Index,
    rtxn: &'a heed::RoTxn<'a>,
    embedder_name: String,
    embedder: Arc<Embedder>,
) -> Self {
    let kind = RecommendKind::Prompt { prompt, context, id };
    let filter = None;
    Self { kind, filter, offset, limit, rtxn, index, embedder_name, embedder }
}
pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self {
@ -62,16 +101,67 @@ impl<'a> Recommend<'a> {
let mut results = Vec::new();
/// FIXME: make id optional...
let id = match &self.kind {
RecommendKind::Id(id) => *id,
RecommendKind::Prompt { prompt, context, id } => id.unwrap(),
};
let personalization_vector = if let RecommendKind::Prompt { prompt, context, id } =
&self.kind
{
let fields_ids_map = self.index.fields_ids_map(self.rtxn)?;
let document = if let Some(id) = id {
Some(self.index.iter_documents(self.rtxn, std::iter::once(*id))?.next().unwrap()?.1)
} else {
None
};
let document = document
.map(|document| crate::prompt::Document::from_doc_obkv(document, &fields_ids_map));
let context =
crate::prompt::recommend::Context::new(document.as_ref(), context.clone());
/// FIXME: handle error bad template
let template =
liquid::ParserBuilder::new().stdlib().build().unwrap().parse(prompt).unwrap();
/// FIXME: handle error bad context
let rendered = template.render(&context).unwrap();
/// FIXME: handle embedding error
Some(self.embedder.embed_one(rendered).unwrap())
} else {
None
};
for reader in readers.iter() {
let nns_by_item = reader.nns_by_item(
self.rtxn,
self.id,
id,
self.limit + self.offset + 1,
None,
Some(&universe),
)?;
if let Some(mut nns_by_item) = nns_by_item {
results.append(&mut nns_by_item);
if let Some(nns_by_item) = nns_by_item {
let mut nns = match &personalization_vector {
Some(vector) => {
let candidates: RoaringBitmap =
nns_by_item.iter().map(|(docid, _)| docid).collect();
reader.nns_by_vector(
self.rtxn,
vector,
self.limit + self.offset + 1,
None,
Some(&candidates),
)?
}
None => nns_by_item,
};
results.append(&mut nns);
}
}