From 683b6afbfb98e8f5bc2fb3314e2ed9ddf911868c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 9 Oct 2019 14:20:37 +0200 Subject: [PATCH] Introduce a way to filter documents with a basic syntax --- meilidb-core/examples/from_file.rs | 30 ++++++++++++++++++++-- meilidb-core/src/serde/deserializer.rs | 4 +-- meilidb-core/src/store/documents_fields.rs | 2 +- meilidb-core/src/store/mod.rs | 24 +++++++++++++---- 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/meilidb-core/examples/from_file.rs b/meilidb-core/examples/from_file.rs index f0fbc61be..513a0b642 100644 --- a/meilidb-core/examples/from_file.rs +++ b/meilidb-core/examples/from_file.rs @@ -53,6 +53,11 @@ struct SearchCommand { #[structopt(short = "C", long, default_value = "35")] char_context: usize, + /// A filter string that can be `!adult` or `adult` to + /// filter documents on this specified field + #[structopt(short, long)] + filter: Option, + /// Fields that must be displayed. displayed_fields: Vec, } @@ -269,8 +274,29 @@ fn search_command(command: SearchCommand, database: Database) -> Result<(), Box< Ok(query) => { let start_total = Instant::now(); - let builder = index.query_builder(); - let documents = builder.query(&reader, &query, 0..command.number_results)?; + let documents = match command.filter { + Some(ref filter) => { + let filter = filter.as_str(); + let (positive, filter) = if filter.chars().next() == Some('!') { + (false, &filter[1..]) + } else { + (true, filter) + }; + + let attr = schema.attribute(&filter).expect("Could not find filtered attribute"); + + let builder = index.query_builder(); + let builder = builder.with_filter(|document_id| { + let string: String = index.document_attribute(&reader, document_id, attr).unwrap().unwrap(); + (string == "true") == positive + }); + builder.query(&reader, &query, 0..command.number_results)? + }, + None => { + let builder = index.query_builder(); + builder.query(&reader, &query, 0..command.number_results)? 
+ } + }; let mut retrieve_duration = Duration::default(); diff --git a/meilidb-core/src/serde/deserializer.rs b/meilidb-core/src/serde/deserializer.rs index dda13892c..ebf008eb7 100644 --- a/meilidb-core/src/serde/deserializer.rs +++ b/meilidb-core/src/serde/deserializer.rs @@ -52,7 +52,7 @@ pub struct Deserializer<'a, R> { pub reader: &'a R, pub documents_fields: DocumentsFields, pub schema: &'a Schema, - pub fields: Option<&'a HashSet>, + pub attributes: Option<&'a HashSet>, } impl<'de, 'a, 'b, R: 'a> de::Deserializer<'de> for &'b mut Deserializer<'a, R> @@ -86,7 +86,7 @@ where R: rkv::Readable, }; let is_displayed = self.schema.props(attr).is_displayed(); - if is_displayed && self.fields.map_or(true, |f| f.contains(&attr)) { + if is_displayed && self.attributes.map_or(true, |f| f.contains(&attr)) { let attribute_name = self.schema.attribute_name(attr); Some((attribute_name, Value::new(value))) } else { diff --git a/meilidb-core/src/store/documents_fields.rs b/meilidb-core/src/store/documents_fields.rs index bc1033807..804508f05 100644 --- a/meilidb-core/src/store/documents_fields.rs +++ b/meilidb-core/src/store/documents_fields.rs @@ -74,7 +74,7 @@ impl DocumentsFields { Ok(count) } - pub fn document_field<'a>( + pub fn document_attribute<'a>( &self, reader: &'a impl rkv::Readable, document_id: DocumentId, diff --git a/meilidb-core/src/store/mod.rs b/meilidb-core/src/store/mod.rs index 8eb108b75..62f28a851 100644 --- a/meilidb-core/src/store/mod.rs +++ b/meilidb-core/src/store/mod.rs @@ -15,7 +15,7 @@ pub use self::updates::Updates; pub use self::updates_results::UpdatesResults; use std::collections::HashSet; -use meilidb_schema::Schema; +use meilidb_schema::{Schema, SchemaAttr}; use serde::de; use crate::{update, query_builder::QueryBuilder, DocumentId, MResult, Error}; use crate::serde::Deserializer; @@ -69,15 +69,15 @@ impl Index { pub fn document( &self, reader: &R, - fields: Option<&HashSet<&str>>, + attributes: Option<&HashSet<&str>>, document_id: 
DocumentId, ) -> MResult> { let schema = self.main.schema(reader)?; let schema = schema.ok_or(Error::SchemaMissing)?; - let fields = match fields { - Some(fields) => fields.into_iter().map(|name| schema.attribute(name)).collect(), + let attributes = match attributes { + Some(attributes) => attributes.into_iter().map(|name| schema.attribute(name)).collect(), None => None, }; @@ -86,7 +86,7 @@ impl Index { reader, documents_fields: self.documents_fields, schema: &schema, - fields: fields.as_ref(), + attributes: attributes.as_ref(), }; // TODO: currently we return an error if all document fields are missing, @@ -94,6 +94,20 @@ impl Index { Ok(T::deserialize(&mut deserializer).map(Some)?) } + pub fn document_attribute( + &self, + reader: &R, + document_id: DocumentId, + attribute: SchemaAttr, + ) -> MResult> + { + let bytes = self.documents_fields.document_attribute(reader, document_id, attribute)?; + match bytes { + Some(bytes) => Ok(Some(rmp_serde::from_read_ref(bytes)?)), + None => Ok(None), + } + } + pub fn schema_update(&self, mut writer: rkv::Writer, schema: Schema) -> MResult<()> { update::push_schema_update(&mut writer, self.updates, self.updates_results, schema)?; writer.commit()?;