diff --git a/src/bin/indexer.rs b/src/bin/indexer.rs index 3c8a4b100..9cad5382d 100644 --- a/src/bin/indexer.rs +++ b/src/bin/indexer.rs @@ -424,7 +424,7 @@ fn main() -> anyhow::Result<()> { .max_dbs(10) .open(&opt.database)?; - let index = Index::new(&env)?; + let mut index = Index::new(&env, &opt.database)?; let documents_path = opt.database.join("documents.mtbl"); let num_threads = rayon::current_num_threads(); @@ -499,13 +499,12 @@ fn main() -> anyhow::Result<()> { let mut builder = Merger::builder(docs_merge); builder.extend(docs_stores); builder.build().write_into(&mut writer)?; - Ok(writer.into_inner()?) as anyhow::Result<_> + Ok(writer.finish()?) as anyhow::Result<_> }); - let file = lmdb.and(mtbl)?; - let mmap = unsafe { Mmap::map(&file)? }; - let documents = Reader::new(mmap)?; - let count = documents.metadata().count_entries; + lmdb.and(mtbl)?; + index.refresh_documents()?; + let count = index.number_of_documents(); debug!("Wrote {} documents into LMDB", count); diff --git a/src/bin/search.rs b/src/bin/search.rs index 6bd29230f..9ecbc1676 100644 --- a/src/bin/search.rs +++ b/src/bin/search.rs @@ -1,4 +1,3 @@ -use std::fs::File; use std::io::{self, Write, BufRead}; use std::iter::once; use std::path::PathBuf; @@ -7,7 +6,6 @@ use std::time::Instant; use heed::EnvOpenOptions; use log::debug; use milli::Index; -use oxidized_mtbl::Reader; use structopt::StructOpt; #[cfg(target_os = "linux")] @@ -47,14 +45,7 @@ fn main() -> anyhow::Result<()> { .open(&opt.database)?; // Open the LMDB database. - let index = Index::new(&env)?; - - // Open the documents MTBL database. - let path = opt.database.join("documents.mtbl"); - let file = File::open(path)?; - let mmap = unsafe { memmap::Mmap::map(&file)? }; - let documents = Reader::new(mmap.as_ref())?; - + let index = Index::new(&env, opt.database)?; let rtxn = env.read_txn()?; let stdin = io::stdin(); @@ -72,15 +63,13 @@ fn main() -> anyhow::Result<()> { Some(headers) => headers, None => return Ok(()), }; + let documents = index.documents(documents_ids.iter().cloned())?; let mut stdout = io::stdout(); stdout.write_all(&headers)?; - for id in &documents_ids { - let id_bytes = id.to_be_bytes(); - if let Some(content) = documents.clone().get(&id_bytes)? { - stdout.write_all(content.as_ref())?; - } + for (_id, content) in documents { + stdout.write_all(&content)?; } debug!("Took {:.02?} to find {} documents", before.elapsed(), documents_ids.len()); diff --git a/src/bin/serve.rs b/src/bin/serve.rs index 5c2f9bf5d..0962ae9da 100644 --- a/src/bin/serve.rs +++ b/src/bin/serve.rs @@ -9,7 +9,6 @@ use std::time::Instant; use askama_warp::Template; use heed::EnvOpenOptions; -use oxidized_mtbl::Reader; use serde::Deserialize; use slice_group_by::StrGroupBy; use structopt::StructOpt; @@ -99,22 +98,13 @@ async fn main() -> anyhow::Result<()> { .open(&opt.database)?; // Open the LMDB database. - let index = Index::new(&env)?; - - // Open the documents MTBL database. - let path = opt.database.join("documents.mtbl"); - let file = File::open(path)?; - let mmap = unsafe { memmap::Mmap::map(&file)? }; - let mmap = TransitiveArc(Arc::new(mmap)); - let documents = Reader::new(mmap)?; + let index = Index::new(&env, &opt.database)?; // Retrieve the database the file stem (w/o the extension), // the disk file size and the number of documents in the database. let db_name = opt.database.file_stem().and_then(|s| s.to_str()).unwrap_or("").to_string(); let db_size = File::open(opt.database.join("data.mdb"))?.metadata()?.len() as usize; - - // Retrieve the documents count. - let docs_count = documents.metadata().count_entries; + let docs_count = index.number_of_documents(); // We run and wait on the HTTP server @@ -198,7 +188,6 @@ async fn main() -> anyhow::Result<()> { } let env_cloned = env.clone(); - let documents_cloned = documents.clone(); let disable_highlighting = opt.disable_highlighting; let query_route = warp::filters::method::post() .and(warp::path!("query")) @@ -213,13 +202,10 @@ async fn main() -> anyhow::Result<()> { if let Some(headers) = index.headers(&rtxn).unwrap() { // We write the headers body.extend_from_slice(headers); + let documents = index.documents(documents_ids).unwrap(); - for id in documents_ids { - let id_bytes = id.to_be_bytes(); - let content = documents_cloned.clone().get(&id_bytes).unwrap(); - let content = content.expect(&format!("could not find document {}", id)); + for (_id, content) in documents { let content = std::str::from_utf8(content.as_ref()).unwrap(); - let content = if disable_highlighting { Cow::from(content) } else { diff --git a/src/lib.rs b/src/lib.rs index fc7a4b7c9..398d17371 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,12 +2,17 @@ mod best_proximity; mod heed_codec; mod iter_shortest_paths; mod query_tokens; +mod transitive_arc; use std::borrow::Cow; use std::collections::{HashSet, HashMap}; +use std::fs::{File, OpenOptions}; use std::hash::BuildHasherDefault; +use std::path::{Path, PathBuf}; +use std::sync::Arc; use std::time::Instant; +use anyhow::Context; use cow_utils::CowUtils; use fst::{IntoStreamer, Streamer}; use fxhash::{FxHasher32, FxHasher64}; @@ -15,12 +20,15 @@ use heed::types::*; use heed::{PolyDatabase, Database}; use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; use log::debug; +use memmap::Mmap; use once_cell::sync::Lazy; +use oxidized_mtbl as omtbl; use roaring::RoaringBitmap; use self::best_proximity::BestProximity; use self::heed_codec::RoaringBitmapCodec; use self::query_tokens::{QueryTokens, QueryToken}; +use self::transitive_arc::TransitiveArc; // Building these factories is not free. static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); @@ -39,6 +47,8 @@ pub type Position = u32; #[derive(Clone)] pub struct Index { + // The database path, where the LMDB and MTBL files are. + path: PathBuf, /// Contains many different types (e.g. the documents CSV headers). pub main: PolyDatabase, /// A word and all the positions where it appears in the whole dataset. @@ -49,20 +59,40 @@ pub struct Index { pub prefix_word_position_docids: Database, /// Maps a word and an attribute (u32) to all the documents ids that it appears in. pub word_attribute_docids: Database, + /// The MTBL store that contains the documents content. + documents: omtbl::Reader>, } impl Index { - pub fn new(env: &heed::Env) -> heed::Result { + pub fn new>(env: &heed::Env, path: P) -> anyhow::Result { + let documents_path = path.as_ref().join("documents.mtbl"); + let mut documents = OpenOptions::new().create(true).write(true).read(true).open(documents_path)?; + // If the file is empty we must initialize it like an empty MTBL database. + if documents.metadata()?.len() == 0 { + omtbl::Writer::new(&mut documents).finish()?; + } + let documents = unsafe { memmap::Mmap::map(&documents)? }; + Ok(Index { + path: path.as_ref().to_path_buf(), main: env.create_poly_database(None)?, word_positions: env.create_database(Some("word-positions"))?, prefix_word_positions: env.create_database(Some("prefix-word-positions"))?, word_position_docids: env.create_database(Some("word-position-docids"))?, prefix_word_position_docids: env.create_database(Some("prefix-word-position-docids"))?, word_attribute_docids: env.create_database(Some("word-attribute-docids"))?, + documents: omtbl::Reader::new(TransitiveArc(Arc::new(documents)))?, }) } + pub fn refresh_documents(&mut self) -> anyhow::Result<()> { + let documents_path = self.path.join("documents.mtbl"); + let documents = File::open(&documents_path)?; + let documents = unsafe { memmap::Mmap::map(&documents)? }; + self.documents = omtbl::Reader::new(TransitiveArc(Arc::new(documents)))?; + Ok(()) + } + pub fn put_headers(&self, wtxn: &mut heed::RwTxn, headers: &[u8]) -> anyhow::Result<()> { Ok(self.main.put::<_, Str, ByteSlice>(wtxn, "headers", headers)?) } @@ -93,6 +123,21 @@ impl Index { } } + /// Returns a [`Vec`] of the requested documents. Returns an error if a document is missing. + pub fn documents>(&self, iter: I) -> anyhow::Result)>> { + iter.into_iter().map(|id| { + let key = id.to_be_bytes(); + let content = self.documents.clone().get(&key)?.with_context(|| format!("Could not find document {}.", id))?; + Ok((id, content.as_ref().to_vec())) + }) + .collect() + } + + /// Returns the number of documents indexed in the database. + pub fn number_of_documents(&self) -> usize { + self.documents.metadata().count_entries as usize + } + pub fn search(&self, rtxn: &heed::RoTxn, query: &str) -> anyhow::Result<(HashSet, Vec)> { let fst = match self.fst(rtxn)? { Some(fst) => fst, diff --git a/src/transitive_arc.rs b/src/transitive_arc.rs new file mode 100644 index 000000000..c25ee3e63 --- /dev/null +++ b/src/transitive_arc.rs @@ -0,0 +1,16 @@ +use std::sync::Arc; + +/// An `Arc<[u8]>` that is transitive over `AsRef<[u8]>`. +pub struct TransitiveArc(pub Arc); + +impl> AsRef<[u8]> for TransitiveArc { + fn as_ref(&self) -> &[u8] { + self.0.as_ref().as_ref() + } +} + +impl Clone for TransitiveArc { + fn clone(&self) -> TransitiveArc { + TransitiveArc(self.0.clone()) + } +}