diff --git a/milli/src/search/criteria/attribute.rs b/milli/src/search/criteria/attribute.rs index 12c6b36b8..af3e08af1 100644 --- a/milli/src/search/criteria/attribute.rs +++ b/milli/src/search/criteria/attribute.rs @@ -326,30 +326,25 @@ impl<'t, 'q> QueryLevelIterator<'t, 'q> { struct Branch<'t, 'q> { query_level_iterator: QueryLevelIterator<'t, 'q>, - last_result: Option<(u32, u32, RoaringBitmap)>, + last_result: (u32, u32, RoaringBitmap), tree_level: TreeLevel, branch_size: u32, } impl<'t, 'q> Branch<'t, 'q> { fn cmp(&self, other: &Self) -> Ordering { - fn compute_rank(left: u32, branch_size: u32) -> u32 { left.saturating_sub((0..branch_size).sum()) / branch_size } - match (&self.last_result, &other.last_result) { - (Some((s_left, _, _)), Some((o_left, _, _))) => { - // we compute a rank from the left interval. - let self_rank = compute_rank(*s_left, self.branch_size); - let other_rank = compute_rank(*o_left, other.branch_size); - let left_cmp = self_rank.cmp(&other_rank).reverse(); - // on level: higher is better, - // we want to reduce highest levels first. - let level_cmp = self.tree_level.cmp(&other.tree_level); + let compute_rank = |left: u32, branch_size: u32| left.saturating_sub((0..branch_size).sum()) / branch_size; + let (s_left, _, _) = self.last_result; + let (o_left, _, _) = other.last_result; + // we compute a rank from the left interval. + let self_rank = compute_rank(s_left, self.branch_size); + let other_rank = compute_rank(o_left, other.branch_size); + let left_cmp = self_rank.cmp(&other_rank).reverse(); + // on level: higher is better, + // we want to reduce highest levels first. + let level_cmp = self.tree_level.cmp(&other.tree_level); - left_cmp.then(level_cmp) - }, - (Some(_), None) => Ordering::Greater, - (None, Some(_)) => Ordering::Less, - (None, None) => Ordering::Equal, - } + left_cmp.then(level_cmp) } } @@ -407,13 +402,15 @@ fn initialize_query_level_iterators<'t, 'q>( if let Some(mut folded_query_level_iterators) = folded_query_level_iterators { let (tree_level, last_result) = folded_query_level_iterators.next()?; - let branch = Branch { - last_result, - tree_level, - query_level_iterator: folded_query_level_iterators, - branch_size: branch.len() as u32, - }; - positions.push(branch); + if let Some(last_result) = last_result { + let branch = Branch { + last_result, + tree_level, + query_level_iterator: folded_query_level_iterators, + branch_size: branch.len() as u32, + }; + positions.push(branch); + } } } @@ -433,28 +430,35 @@ fn set_compute_candidates<'t>( while let Some(mut branch) = branches_heap.peek_mut() { let is_lowest_level = branch.tree_level == lowest_level; - match branch.last_result.as_mut() { - Some((_, _, candidates)) => { - candidates.intersect_with(&allowed_candidates); - if candidates.len() > 0 && is_lowest_level { - // we have candidates, but we can't dig deeper, return candidates. - final_candidates = Some(std::mem::take(candidates)); - break; - } else if candidates.len() > 0 { - // we have candidates, lets dig deeper in levels. - let mut query_level_iterator = branch.query_level_iterator.dig(ctx)?; - let (tree_level, last_result) = query_level_iterator.next()?; + let (_, _, candidates) = &mut branch.last_result; + candidates.intersect_with(&allowed_candidates); + if candidates.is_empty() { + // we don't have candidates, get next interval. + match branch.query_level_iterator.next()? { + (_, Some(last_result)) => { + branch.last_result = last_result; + }, + // TODO clean up this + (_, None) => { std::collections::binary_heap::PeekMut::<'_, Branch<'_, '_>>::pop(branch); }, + } + + } + else if is_lowest_level { + // we have candidates, but we can't dig deeper, return candidates. + final_candidates = Some(take(candidates)); + break; + } else { + // we have candidates, lets dig deeper in levels. + let mut query_level_iterator = branch.query_level_iterator.dig(ctx)?; + match query_level_iterator.next()? { + (tree_level, Some(last_result)) => { branch.query_level_iterator = query_level_iterator; branch.tree_level = tree_level; branch.last_result = last_result; - } else { - // we don't have candidates, get next interval. - let (_, last_result) = branch.query_level_iterator.next()?; - branch.last_result = last_result; - } - }, - // None = no candidates to find. - None => break, + }, + // TODO clean up this + (_, None) => { std::collections::binary_heap::PeekMut::<'_, Branch<'_, '_>>::pop(branch); }, + } } }