zebra/
distance.rs

1use std::ops::Deref;
2
3use crate::{Embedding, EmbeddingPrecision};
4use distances::vectors::{
5    bray_curtis, canberra, chebyshev, cosine, euclidean, euclidean_sq, hamming, l3_norm, l4_norm,
6    manhattan, minkowski, minkowski_p,
7};
8use serde::{Deserialize, Serialize};
9use simsimd::SpatialSimilarity;
10use space::Metric;
11
/// The data type representing the distance between two embeddings.
///
/// Every metric in this module encodes its floating-point result as the raw bit pattern
/// returned by `to_bits()`, converted into this `u64`. For non-negative IEEE-754 values the
/// bit pattern orders the same way as the value itself, so units can be compared directly.
pub type DistanceUnit = u64;
14
15#[derive(Default, Debug, Clone, Serialize, Deserialize)]
16/// The cosine distance metric.
17pub struct CosineDistance<const N: usize>;
18
19impl<const N: usize> Metric<Embedding<N>> for CosineDistance<N> {
20    type Unit = DistanceUnit;
21    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
22        // Use SIMD if vectors are of same length; otherwise, find distance after truncating longer vector so lengths match
23        EmbeddingPrecision::cosine(a.deref(), b.deref())
24            .map(|c| 1.0 - c)
25            .map(|x| x.to_bits())
26            .unwrap_or(
27                cosine::<_, EmbeddingPrecision>(a.deref(), b.deref())
28                    .to_bits()
29                    .into(),
30            )
31    }
32}
33
34#[derive(Default, Debug, Clone, Serialize, Deserialize)]
35/// The L2-squared distance metric.
36pub struct L2SquaredDistance<const N: usize>;
37
38impl<const N: usize> Metric<Embedding<N>> for L2SquaredDistance<N> {
39    type Unit = DistanceUnit;
40    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
41        EmbeddingPrecision::sqeuclidean(a.deref(), b.deref())
42            .map(|x| x.to_bits())
43            .unwrap_or(
44                euclidean_sq::<_, EmbeddingPrecision>(a.deref(), b.deref())
45                    .to_bits()
46                    .into(),
47            )
48    }
49}
50
51#[derive(Default, Debug, Clone, Serialize, Deserialize)]
52/// The Chebyshev distance metric.
53pub struct ChebyshevDistance<const N: usize>;
54
55impl<const N: usize> Metric<Embedding<N>> for ChebyshevDistance<N> {
56    type Unit = DistanceUnit;
57    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
58        let chebyshev_distance = chebyshev(a.deref(), b.deref());
59        chebyshev_distance.to_bits().into()
60    }
61}
62
63#[derive(Default, Debug, Clone, Serialize, Deserialize)]
64/// The Canberra distance metric.
65pub struct CanberraDistance<const N: usize>;
66
67impl<const N: usize> Metric<Embedding<N>> for CanberraDistance<N> {
68    type Unit = DistanceUnit;
69    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
70        let canberra_distance: EmbeddingPrecision = canberra(a.deref(), b.deref());
71        canberra_distance.to_bits().into()
72    }
73}
74
75#[derive(Default, Debug, Clone, Serialize, Deserialize)]
76/// The Bray-Curtis distance metric.
77pub struct BrayCurtisDistance<const N: usize>;
78
79impl<const N: usize> Metric<Embedding<N>> for BrayCurtisDistance<N> {
80    type Unit = DistanceUnit;
81    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
82        let bray_curtis_distance: EmbeddingPrecision = bray_curtis(a.deref(), b.deref());
83        bray_curtis_distance.to_bits().into()
84    }
85}
86
87#[derive(Default, Debug, Clone, Serialize, Deserialize)]
88/// The Manhattan distance metric.
89pub struct ManhattanDistance<const N: usize>;
90
91impl<const N: usize> Metric<Embedding<N>> for ManhattanDistance<N> {
92    type Unit = DistanceUnit;
93    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
94        let manhattan_distance: EmbeddingPrecision = manhattan(a.deref(), b.deref());
95        manhattan_distance.to_bits().into()
96    }
97}
98
99#[derive(Default, Debug, Clone, Serialize, Deserialize)]
100/// The L2 distance metric.
101pub struct L2Distance<const N: usize>;
102
103impl<const N: usize> Metric<Embedding<N>> for L2Distance<N> {
104    type Unit = DistanceUnit;
105    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
106        EmbeddingPrecision::euclidean(a.deref(), b.deref())
107            .map(|x| x.to_bits())
108            .unwrap_or(
109                euclidean::<_, EmbeddingPrecision>(a.deref(), b.deref())
110                    .to_bits()
111                    .into(),
112            )
113    }
114}
115
116#[derive(Default, Debug, Clone, Serialize, Deserialize)]
117/// The L3 distance metric.
118pub struct L3Distance<const N: usize>;
119
120impl<const N: usize> Metric<Embedding<N>> for L3Distance<N> {
121    type Unit = DistanceUnit;
122    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
123        let l3_distance: EmbeddingPrecision = l3_norm(a.deref(), b.deref());
124        l3_distance.to_bits().into()
125    }
126}
127
128#[derive(Default, Debug, Clone, Serialize, Deserialize)]
129/// The L4 distance metric.
130pub struct L4Distance<const N: usize>;
131
132impl<const N: usize> Metric<Embedding<N>> for L4Distance<N> {
133    type Unit = DistanceUnit;
134    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
135        let l4_distance: EmbeddingPrecision = l4_norm(a.deref(), b.deref());
136        l4_distance.to_bits().into()
137    }
138}
139
140#[derive(Default, Debug, Clone, Serialize, Deserialize)]
141/// The Hamming distance metric.
142pub struct HammingDistance<const N: usize>;
143
144impl<const N: usize> Metric<Embedding<N>> for HammingDistance<N> {
145    type Unit = DistanceUnit;
146    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
147        let a_to_bits: Vec<u8> = a.iter().map(|x| x.to_bits() as u8).collect();
148        let b_to_bits: Vec<u8> = b.iter().map(|x| x.to_bits() as u8).collect();
149        match a.len() == b.len() {
150            true => hamming_bitwise_fast::hamming_bitwise_fast(
151                a_to_bits.as_slice(),
152                b_to_bits.as_slice(),
153            )
154            .into(),
155            false => hamming::<_, u32>(a_to_bits.as_slice(), b_to_bits.as_slice()).into(),
156        }
157    }
158}
159
160#[derive(Default, Debug, Clone, Serialize, Deserialize)]
161/// The Minkowski distance metric.
162pub struct MinkowskiDistance<const N: usize> {
163    /// The power of the Minkowski distance.
164    pub power: i32,
165}
166
167impl<const N: usize> Metric<Embedding<N>> for MinkowskiDistance<N> {
168    type Unit = DistanceUnit;
169    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
170        let metric = minkowski(self.power);
171        let distance: EmbeddingPrecision = metric(a.deref(), b.deref());
172        distance.to_bits().into()
173    }
174}
175
176#[derive(Default, Debug, Clone, Serialize, Deserialize)]
177/// The p-norm distance metric.
178pub struct PNormDistance<const N: usize> {
179    /// The power of the distance metric.
180    pub power: i32,
181}
182
183impl<const N: usize> Metric<Embedding<N>> for PNormDistance<N> {
184    type Unit = DistanceUnit;
185    fn distance(&self, a: &Embedding<N>, b: &Embedding<N>) -> Self::Unit {
186        let metric = minkowski_p(self.power);
187        let distance: EmbeddingPrecision = metric(a.deref(), b.deref());
188        distance.to_bits().into()
189    }
190}