From 3a998cf39c878ec49c708223430ae3885d596fc0 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Sat, 30 May 2020 19:56:57 +0200 Subject: [PATCH] Far better usage of rayon to fold indexed data --- Cargo.lock | 11 --- Cargo.toml | 1 - src/main.rs | 208 +++++++++++++++++++++++++++------------------------- 3 files changed, 107 insertions(+), 113 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 24008fc08..abc44ba56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -84,16 +84,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "crossbeam-channel" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cced8691919c02aac3cb0a1bc2e9b73d89e832bf9a06fc579d4e71b68a2da061" -dependencies = [ - "crossbeam-utils", - "maybe-uninit", -] - [[package]] name = "crossbeam-deque" version = "0.7.3" @@ -330,7 +320,6 @@ dependencies = [ "anyhow", "bitpacking", "byteorder 1.3.4", - "crossbeam-channel", "csv", "fst", "fxhash", diff --git a/Cargo.toml b/Cargo.toml index 51966ebf9..bf5f76152 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,6 @@ anyhow = "1.0.28" bitpacking = "0.8.2" byteorder = "1.3.4" roaring = "0.5.2" -crossbeam-channel = "0.4.2" csv = "1.1.3" fst = "0.4.3" fxhash = "0.2.1" diff --git a/src/main.rs b/src/main.rs index a2784cae3..553357612 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,16 @@ +use std::collections::hash_map::Entry; use std::collections::{HashMap, BTreeSet}; use std::convert::TryFrom; use std::fs::File; use std::hash::BuildHasherDefault; use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::thread; use anyhow::{ensure, Context}; use roaring::RoaringBitmap; -use crossbeam_channel::{select, Sender, Receiver}; use fst::IntoStreamer; use fxhash::FxHasher32; -use heed::{EnvOpenOptions, Database}; +use heed::{EnvOpenOptions, PolyDatabase, Database}; use heed::types::*; use rayon::prelude::*; use slice_group_by::StrGroupBy; @@ -81,97 +80,67 @@ fn alphanumeric_tokens(string: &str) -> impl Iterator { string.linear_group_by_key(|c| c.is_alphanumeric()).filter(is_alphanumeric) } -enum MainKey { - WordsFst(fst::Set>), - Headers(Vec), +#[derive(Default)] +struct Indexed { + fst: fst::Set>, + postings_ids: FastMap4, + headers: Vec, + documents: Vec<(DocumentId, Vec)>, } -#[derive(Clone)] -struct DbSender { - main: Sender, - postings_ids: Sender<(SmallString32, RoaringBitmap)>, - documents: Sender<(DocumentId, Vec)>, -} +impl Indexed { + fn merge_with(mut self, mut other: Indexed) -> Indexed { -struct DbReceiver { - main: Receiver, - postings_ids: Receiver<(SmallString32, RoaringBitmap)>, - documents: Receiver<(DocumentId, Vec)>, -} + // Union of the two FSTs + let op = fst::set::OpBuilder::new() + .add(self.fst.into_stream()) + .add(other.fst.into_stream()) + .r#union(); -fn thread_channel() -> (DbSender, DbReceiver) { - let (sd_main, rc_main) = crossbeam_channel::bounded(4); - let (sd_postings, rc_postings) = crossbeam_channel::bounded(10); - let (sd_documents, rc_documents) = crossbeam_channel::bounded(10); + let mut build = fst::SetBuilder::memory(); + build.extend_stream(op.into_stream()).unwrap(); + let fst = build.into_set(); - let sender = DbSender { main: sd_main, postings_ids: sd_postings, documents: sd_documents }; - let receiver = DbReceiver { main: rc_main, postings_ids: rc_postings, documents: rc_documents }; + // Merge the postings by unions + for (word, mut postings) in other.postings_ids { + match self.postings_ids.entry(word) { + Entry::Occupied(mut entry) => { + let old = entry.get(); + postings.union_with(&old); + entry.insert(postings); + }, + Entry::Vacant(entry) => { + entry.insert(postings); + }, + } + } - (sender, receiver) -} + // assert headers are valid + assert_eq!(self.headers, other.headers); -fn writer_thread(env: heed::Env, receiver: DbReceiver) -> anyhow::Result<()> { - let main = env.create_poly_database(None)?; - let postings_ids: Database = env.create_database(Some("postings-ids"))?; - let documents: Database, ByteSlice> = env.create_database(Some("documents"))?; + // extend the documents + self.documents.append(&mut other.documents); - let mut wtxn = env.write_txn()?; - - loop { - select! { - recv(receiver.main) -> msg => { - let msg = match msg { - Err(_) => break, - Ok(msg) => msg, - }; - - match msg { - MainKey::WordsFst(new_fst) => { - let old_value = main.get::<_, Str, ByteSlice>(&wtxn, "words-fst")?; - let new_value = union_words_fst(b"words-fst", old_value, &new_fst) - .context("error while do a words-fst union")?; - main.put::<_, Str, ByteSlice>(&mut wtxn, "words-fst", &new_value)?; - }, - MainKey::Headers(headers) => { - if let Some(old_headers) = main.get::<_, Str, ByteSlice>(&wtxn, "headers")? { - ensure!(old_headers == &*headers, "headers differs from the previous ones"); - } - main.put::<_, Str, ByteSlice>(&mut wtxn, "headers", &headers)?; - }, - } - }, - recv(receiver.postings_ids) -> msg => { - let (word, postings) = match msg { - Err(_) => break, - Ok(msg) => msg, - }; - - let old_value = postings_ids.get(&wtxn, &word)?; - let new_value = union_postings_ids(word.as_bytes(), old_value, postings) - .context("error while do a words-fst union")?; - postings_ids.put(&mut wtxn, &word, &new_value)?; - }, - recv(receiver.documents) -> msg => { - let (id, content) = match msg { - Err(_) => break, - Ok(msg) => msg, - }; - documents.put(&mut wtxn, &BEU32::new(id), &content)?; - }, + Indexed { + fst, + postings_ids: self.postings_ids, + headers: self.headers, + documents: self.documents, } } - - wtxn.commit()?; - Ok(()) } -fn index_csv(tid: usize, db_sender: DbSender, mut rdr: csv::Reader) -> anyhow::Result { +fn index_csv( + tid: usize, + mut rdr: csv::Reader, +) -> anyhow::Result +{ const MAX_POSITION: usize = 1000; const MAX_ATTRIBUTES: usize = u32::max_value() as usize / MAX_POSITION; let mut document = csv::StringRecord::new(); - let mut new_postings_ids = FastMap4::default(); - let mut new_words = BTreeSet::default(); + let mut postings_ids = FastMap4::default(); + let mut documents = Vec::new(); let mut number_of_documents = 0; // Write the headers into a Vec of bytes. @@ -179,7 +148,6 @@ fn index_csv(tid: usize, db_sender: DbSender, mut rdr: csv::Reader) -> any let mut writer = csv::WriterBuilder::new().has_headers(false).from_writer(Vec::new()); writer.write_byte_record(headers.as_byte_record())?; let headers = writer.into_inner()?; - db_sender.main.send(MainKey::Headers(headers))?; while rdr.read_record(&mut document)? { let document_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst); @@ -188,7 +156,7 @@ fn index_csv(tid: usize, db_sender: DbSender, mut rdr: csv::Reader) -> any for (_attr, content) in document.iter().enumerate().take(MAX_ATTRIBUTES) { for (_pos, word) in alphanumeric_tokens(&content).enumerate().take(MAX_POSITION) { if !word.is_empty() && word.len() < 500 { // LMDB limits - new_postings_ids.entry(SmallString32::from(word)) + postings_ids.entry(SmallString32::from(word)) .or_insert_with(RoaringBitmap::new) .insert(document_id); } @@ -199,7 +167,7 @@ fn index_csv(tid: usize, db_sender: DbSender, mut rdr: csv::Reader) -> any let mut writer = csv::WriterBuilder::new().has_headers(false).from_writer(Vec::new()); writer.write_byte_record(document.as_byte_record())?; let document = writer.into_inner()?; - db_sender.documents.send((document_id, document))?; + documents.push((document_id, document)); number_of_documents += 1; if number_of_documents % 100000 == 0 { @@ -207,26 +175,57 @@ fn index_csv(tid: usize, db_sender: DbSender, mut rdr: csv::Reader) -> any } } - eprintln!("Start collecting the postings lists and words"); + eprintln!("Start collecting the words into an FST"); // We compute and store the postings list into the DB. - for (word, new_ids) in new_postings_ids { - db_sender.postings_ids.send((word.clone(), new_ids))?; - new_words.insert(word); + let mut new_words = BTreeSet::default(); + for (word, _new_ids) in &postings_ids { + new_words.insert(word.clone()); } - eprintln!("Finished collecting the postings lists and words"); + let new_words_fst = fst::Set::from_iter(new_words.iter().map(SmallString32::as_str))?; - eprintln!("Start merging the words-fst"); + eprintln!("Total number of documents seen so far is {}", ID_GENERATOR.load(Ordering::Relaxed)); - let new_words_fst = fst::Set::from_iter(new_words.iter().map(|s| s.as_str()))?; - drop(new_words); - db_sender.main.send(MainKey::WordsFst(new_words_fst))?; + Ok(Indexed { fst: new_words_fst, headers, postings_ids, documents }) +} - eprintln!("Finished merging the words-fst"); - eprintln!("Total number of documents seen is {}", ID_GENERATOR.load(Ordering::Relaxed)); +fn writer( + wtxn: &mut heed::RwTxn, + main: PolyDatabase, + postings_ids: Database, + documents: Database, ByteSlice>, + indexed: Indexed, +) -> anyhow::Result +{ + // Write and merge the words fst + let old_value = main.get::<_, Str, ByteSlice>(wtxn, "words-fst")?; + let new_value = union_words_fst(b"words-fst", old_value, &indexed.fst) + .context("error while do a words-fst union")?; + main.put::<_, Str, ByteSlice>(wtxn, "words-fst", &new_value)?; - Ok(number_of_documents) + // Write and merge the headers + if let Some(old_headers) = main.get::<_, Str, ByteSlice>(wtxn, "headers")? { + ensure!(old_headers == &*indexed.headers, "headers differs from the previous ones"); + } + main.put::<_, Str, ByteSlice>(wtxn, "headers", &indexed.headers)?; + + // Write and merge the postings lists + for (word, postings) in indexed.postings_ids { + let old_value = postings_ids.get(wtxn, word.as_str())?; + let new_value = union_postings_ids(word.as_bytes(), old_value, postings) + .context("error while do a words-fst union")?; + postings_ids.put(wtxn, &word, &new_value)?; + } + + let count = indexed.documents.len(); + + // Write the documents + for (id, content) in indexed.documents { + documents.put(wtxn, &BEU32::new(id), &content)?; + } + + Ok(count) } fn main() -> anyhow::Result<()> { @@ -239,22 +238,29 @@ fn main() -> anyhow::Result<()> { .max_dbs(5) .open(opt.database)?; - let (sender, receiver) = thread_channel(); - let writing_child = thread::spawn(move || writer_thread(env, receiver)); + let main = env.create_poly_database(None)?; + let postings_ids: Database = env.create_database(Some("postings-ids"))?; + let documents: Database, ByteSlice> = env.create_database(Some("documents"))?; let res = opt.files_to_index .into_par_iter() .enumerate() - .map(|(tid, path)| { + .try_fold(|| Indexed::default(), |acc, (tid, path)| { let rdr = csv::Reader::from_path(path)?; - index_csv(tid, sender.clone(), rdr) + let indexed = index_csv(tid, rdr)?; + Ok(acc.merge_with(indexed)) as anyhow::Result + }) + .map(|indexed| match indexed { + Ok(indexed) => { + let mut wtxn = env.write_txn()?; + let count = writer(&mut wtxn, main, postings_ids, documents, indexed)?; + wtxn.commit()?; + Ok(count) + }, + Err(e) => Err(e), }) .try_reduce(|| 0, |a, b| Ok(a + b)); - - eprintln!("witing the writing thread..."); - writing_child.join().unwrap().unwrap(); - println!("indexed {:?} documents", res); Ok(())