use enum_iterator::IntoEnumIterator; use std::borrow::Cow; use std::cmp::Reverse; use std::convert::TryFrom; use std::convert::TryInto; use std::fs::create_dir_all; use std::path::Path; use std::str; use std::sync::Arc; use milli::heed::types::{ByteSlice, DecodeIgnore, SerdeJson}; use milli::heed::{Database, Env, EnvOpenOptions, RwTxn}; use time::OffsetDateTime; use super::error::Result; use super::{Action, Key}; const AUTH_STORE_SIZE: usize = 1_073_741_824; //1GiB pub const KEY_ID_LENGTH: usize = 8; const AUTH_DB_PATH: &str = "auth"; const KEY_DB_NAME: &str = "api-keys"; const KEY_ID_ACTION_INDEX_EXPIRATION_DB_NAME: &str = "keyid-action-index-expiration"; pub type KeyId = [u8; KEY_ID_LENGTH]; #[derive(Clone)] pub struct HeedAuthStore { env: Arc, keys: Database>, action_keyid_index_expiration: Database>>, should_close_on_drop: bool, } impl Drop for HeedAuthStore { fn drop(&mut self) { if self.should_close_on_drop && Arc::strong_count(&self.env) == 1 { self.env.as_ref().clone().prepare_for_closing(); } } } pub fn open_auth_store_env(path: &Path) -> milli::heed::Result { let mut options = EnvOpenOptions::new(); options.map_size(AUTH_STORE_SIZE); // 1GB options.max_dbs(2); options.open(path) } impl HeedAuthStore { pub fn new(path: impl AsRef) -> Result { let path = path.as_ref().join(AUTH_DB_PATH); create_dir_all(&path)?; let env = Arc::new(open_auth_store_env(path.as_ref())?); let keys = env.create_database(Some(KEY_DB_NAME))?; let action_keyid_index_expiration = env.create_database(Some(KEY_ID_ACTION_INDEX_EXPIRATION_DB_NAME))?; Ok(Self { env, keys, action_keyid_index_expiration, should_close_on_drop: true, }) } pub fn set_drop_on_close(&mut self, v: bool) { self.should_close_on_drop = v; } pub fn is_empty(&self) -> Result { let rtxn = self.env.read_txn()?; Ok(self.keys.len(&rtxn)? == 0) } pub fn put_api_key(&self, key: Key) -> Result { let mut wtxn = self.env.write_txn()?; self.keys.put(&mut wtxn, &key.id, &key)?; let id = key.id; // delete key from inverted database before refilling it. self.delete_key_from_inverted_db(&mut wtxn, &id)?; // create inverted database. let db = self.action_keyid_index_expiration; let actions = if key.actions.contains(&Action::All) { // if key.actions contains All, we iterate over all actions. Action::into_enum_iter().collect() } else { key.actions.clone() }; let no_index_restriction = key.indexes.contains(&"*".to_owned()); for action in actions { if no_index_restriction { // If there is no index restriction we put None. db.put(&mut wtxn, &(&id, &action, None), &key.expires_at)?; } else { // else we create a key for each index. for index in key.indexes.iter() { db.put( &mut wtxn, &(&id, &action, Some(index.as_bytes())), &key.expires_at, )?; } } } wtxn.commit()?; Ok(key) } pub fn get_api_key(&self, key: impl AsRef) -> Result> { let rtxn = self.env.read_txn()?; match self.get_key_id(key.as_ref().as_bytes()) { Some(id) => self.keys.get(&rtxn, &id).map_err(|e| e.into()), None => Ok(None), } } pub fn delete_api_key(&self, key: impl AsRef) -> Result { let mut wtxn = self.env.write_txn()?; let existing = match self.get_key_id(key.as_ref().as_bytes()) { Some(id) => { let existing = self.keys.delete(&mut wtxn, &id)?; self.delete_key_from_inverted_db(&mut wtxn, &id)?; existing } None => false, }; wtxn.commit()?; Ok(existing) } pub fn list_api_keys(&self) -> Result> { let mut list = Vec::new(); let rtxn = self.env.read_txn()?; for result in self.keys.remap_key_type::().iter(&rtxn)? { let (_, content) = result?; list.push(content); } list.sort_unstable_by_key(|k| Reverse(k.created_at)); Ok(list) } pub fn get_expiration_date( &self, key: &[u8], action: Action, index: Option<&[u8]>, ) -> Result>> { let rtxn = self.env.read_txn()?; match self.get_key_id(key) { Some(id) => { let tuple = (&id, &action, index); Ok(self.action_keyid_index_expiration.get(&rtxn, &tuple)?) } None => Ok(None), } } pub fn prefix_first_expiration_date( &self, key: &[u8], action: Action, ) -> Result>> { let rtxn = self.env.read_txn()?; match self.get_key_id(key) { Some(id) => { let tuple = (&id, &action, None); Ok(self .action_keyid_index_expiration .prefix_iter(&rtxn, &tuple)? .next() .transpose()? .map(|(_, expiration)| expiration)) } None => Ok(None), } } pub fn get_key_id(&self, key: &[u8]) -> Option { try_split_array_at::<_, KEY_ID_LENGTH>(key).map(|(id, _)| *id) } fn delete_key_from_inverted_db(&self, wtxn: &mut RwTxn, key: &KeyId) -> Result<()> { let mut iter = self .action_keyid_index_expiration .remap_types::() .prefix_iter_mut(wtxn, key)?; while iter.next().transpose()?.is_some() { // safety: we don't keep references from inside the LMDB database. unsafe { iter.del_current()? }; } Ok(()) } } /// Codec allowing to retrieve the expiration date of an action, /// optionnally on a spcific index, for a given key. pub struct KeyIdActionCodec; impl<'a> milli::heed::BytesDecode<'a> for KeyIdActionCodec { type DItem = (KeyId, Action, Option<&'a [u8]>); fn bytes_decode(bytes: &'a [u8]) -> Option { let (key_id, action_bytes) = try_split_array_at(bytes)?; let (action_bytes, index) = match try_split_array_at(action_bytes)? { (action, []) => (action, None), (action, index) => (action, Some(index)), }; let action = Action::from_repr(u8::from_be_bytes(*action_bytes))?; Some((*key_id, action, index)) } } impl<'a> milli::heed::BytesEncode<'a> for KeyIdActionCodec { type EItem = (&'a KeyId, &'a Action, Option<&'a [u8]>); fn bytes_encode((key_id, action, index): &Self::EItem) -> Option> { let mut bytes = Vec::new(); bytes.extend_from_slice(*key_id); let action_bytes = u8::to_be_bytes(action.repr()); bytes.extend_from_slice(&action_bytes); if let Some(index) = index { bytes.extend_from_slice(index); } Some(Cow::Owned(bytes)) } } /// Divides one slice into two at an index, returns `None` if mid is out of bounds. pub fn try_split_at(slice: &[T], mid: usize) -> Option<(&[T], &[T])> { if mid <= slice.len() { Some(slice.split_at(mid)) } else { None } } /// Divides one slice into an array and the tail at an index, /// returns `None` if `N` is out of bounds. pub fn try_split_array_at(slice: &[T]) -> Option<(&[T; N], &[T])> where [T; N]: for<'a> TryFrom<&'a [T]>, { let (head, tail) = try_split_at(slice, N)?; let head = head.try_into().ok()?; Some((head, tail)) }