mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-30 00:34:26 +01:00
Introduce an optimized version of the euclidean distance function
This commit is contained in:
parent
268a9ef416
commit
5816008139
@ -1,6 +1,13 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use space::Metric;
|
use space::Metric;
|
||||||
|
|
||||||
|
// Minimum number of `f32` components a vector must have before `distance`
// switches from the scalar loop to the SIMD kernel (compared against `a.len()`).
// NOTE(review): the x86/x86_64 arms of this `cfg` have no visible user in this
// part of the file — presumably reserved for x86 SIMD kernels; confirm.
#[cfg(any(
    target_arch = "x86",
    target_arch = "x86_64",
    all(target_arch = "aarch64", target_feature = "neon")
))]
const MIN_DIM_SIZE_SIMD: usize = 16;
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
|
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
|
||||||
pub struct DotProduct;
|
pub struct DotProduct;
|
||||||
|
|
||||||
@ -26,9 +33,58 @@ impl Metric<Vec<f32>> for Euclidean {
|
|||||||
type Unit = u32;
|
type Unit = u32;
|
||||||
|
|
||||||
/// Returns the Euclidean distance between `a` and `b`, encoded as the raw
/// bit pattern of the resulting `f32` (`f32::to_bits`).
///
/// On aarch64 builds compiled with NEON, vectors of equal length with at
/// least `MIN_DIM_SIZE_SIMD` components go through `squared_euclid_neon`;
/// every other input uses the scalar loop. The scalar loop zips the two
/// vectors, so when the lengths differ only the common prefix is compared.
///
/// In debug builds this asserts that the computed distance is not NaN.
fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    {
        // The NEON kernel walks both inputs using the length of the first one,
        // so it may only be used when the lengths match; mismatched inputs
        // fall through to the scalar path instead of reading out of bounds.
        if a.len() == b.len()
            && a.len() >= MIN_DIM_SIZE_SIMD
            && std::arch::is_aarch64_feature_detected!("neon")
        {
            // SAFETY: NEON is available (compile-time `cfg` plus the runtime
            // check above) and both slices have the same length, so every
            // element the kernel loads lies inside its slice.
            let squared = unsafe { squared_euclid_neon(a, b) };
            let dist = squared.sqrt();
            debug_assert!(!dist.is_nan());
            return dist.to_bits();
        }
    }

    let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum();
    let dist = squared.sqrt();
    debug_assert!(!dist.is_nan());
    dist.to_bits()
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `std::arch::aarch64` only exists on aarch64, so gate on the architecture as
// well as on the feature (32-bit ARM can also report `target_feature = "neon"`).
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
use std::arch::aarch64::*;

/// Computes the squared Euclidean distance between `v1` and `v2` using NEON.
///
/// The inputs are processed 16 floats at a time with four independent
/// 4-lane accumulators; the fewer-than-16 trailing elements are added with
/// scalar arithmetic. If the slices differ in length, only their common
/// prefix is compared (the same result a `zip` over both would give), so no
/// load ever goes past the end of the shorter slice.
///
/// # Safety
///
/// The executing CPU must support NEON. This item is only compiled when the
/// crate is built with `target_feature = "neon"`. The function stays
/// `unsafe` to keep its existing signature.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) unsafe fn squared_euclid_neon(v1: &[f32], v2: &[f32]) -> f32 {
    // Four 128-bit registers of four `f32` lanes each per loop iteration.
    const FLOATS_PER_ITER: usize = 16;

    // Clamp both inputs to the common prefix so that every access is in bounds.
    let n = v1.len().min(v2.len());
    let (v1, v2) = (&v1[..n], &v2[..n]);

    let blocks1 = v1.chunks_exact(FLOATS_PER_ITER);
    let blocks2 = v2.chunks_exact(FLOATS_PER_ITER);
    // The remainders borrow from the slices, not from the iterators, so they
    // can be taken before the iterators are consumed below.
    let (tail1, tail2) = (blocks1.remainder(), blocks2.remainder());

    let mut sum1 = vdupq_n_f32(0.);
    let mut sum2 = vdupq_n_f32(0.);
    let mut sum3 = vdupq_n_f32(0.);
    let mut sum4 = vdupq_n_f32(0.);

    for (block1, block2) in blocks1.zip(blocks2) {
        // Each block holds exactly `FLOATS_PER_ITER` floats, so the four
        // 4-float loads at offsets 0, 4, 8 and 12 stay inside it.
        let ptr1 = block1.as_ptr();
        let ptr2 = block2.as_ptr();

        let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
        sum1 = vfmaq_f32(sum1, sub1, sub1);

        let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
        sum2 = vfmaq_f32(sum2, sub2, sub2);

        let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
        sum3 = vfmaq_f32(sum3, sub3, sub3);

        let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
        sum4 = vfmaq_f32(sum4, sub4, sub4);
    }

    // Horizontal reduction of the four accumulators.
    let simd_total = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);

    // Scalar tail, accumulated onto the SIMD total one element at a time.
    tail1
        .iter()
        .zip(tail2)
        .fold(simd_total, |acc, (x, y)| acc + (x - y).powi(2))
}
|
||||||
|
Loading…
Reference in New Issue
Block a user