在向量搜索引擎中,距离度量是判断两个向量相似程度的核心方法。不同的距离度量适用于不同的应用场景:
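一个有效的距离度量(即数学意义上的度量)应满足以下性质:非负性 $d(\mathbf{a},\mathbf{b}) \ge 0$;同一性 $d(\mathbf{a},\mathbf{b}) = 0 \iff \mathbf{a} = \mathbf{b}$;对称性 $d(\mathbf{a},\mathbf{b}) = d(\mathbf{b},\mathbf{a})$;三角不等式 $d(\mathbf{a},\mathbf{c}) \le d(\mathbf{a},\mathbf{b}) + d(\mathbf{b},\mathbf{c})$。注意,下文中的余弦距离、内积距离等并不满足全部性质(例如内积距离可以为负),严格来说它们只是"相异度"。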
package com.jvector.core;
/**
 * Distance metric interface.
 *
 * <p>Defines the standard way to compute the distance between two vectors.
 */
public interface DistanceMetric {

    /**
     * Computes the distance between two vectors.
     *
     * @param a data array of vector a
     * @param b data array of vector b
     * @return the distance; smaller values mean the vectors are more similar
     */
    float distance(float[] a, float[] b);

    /**
     * Returns the name of this distance metric.
     *
     * @return the metric name
     */
    String getName();

    /**
     * Reports whether this metric is a similarity (larger values mean more similar).
     *
     * <p>The default implementation returns {@code false}, i.e. the metric is a distance.
     *
     * @return {@code true} for a similarity, {@code false} for a distance
     */
    default boolean isSimilarity() {
        return false;
    }
}
欧几里得距离是最常用的距离度量,计算两点在空间中的直线距离:
$$d(\mathbf{a},\mathbf{b}) = \sqrt{\sum_{i=1}^{n} (a_i - b_i)^2}$$
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 欧几里得距离(L2距离)实现
*/
publicclass EuclideanDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException(
String.format("Vector dimensions must match: %d vs %d", a.length, b.length));
}
float sum = 0.0f;
for (int i = 0; i < a.length; i++) {
float diff = a[i] - b[i];
sum += diff * diff;
}
return (float) Math.sqrt(sum);
}
@Override
public String getName() {
return"euclidean";
}
/**
* 计算平方距离(避免开方运算,提高性能)
* @param a 向量a
* @param b 向量b
* @return 平方距离
*/
public float squaredDistance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float sum = 0.0f;
for (int i = 0; i < a.length; i++) {
float diff = a[i] - b[i];
sum += diff * diff;
}
return sum;
}
}
/**
* SIMD优化的欧几里得距离计算
*/
publicclass OptimizedEuclideanDistance extends EuclideanDistance {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
return (float) Math.sqrt(squaredDistanceOptimized(a, b));
}
/**
* 使用循环展开优化的平方距离计算
*/
private float squaredDistanceOptimized(float[] a, float[] b) {
int len = a.length;
float sum = 0.0f;
int i = 0;
// 循环展开,一次处理4个元素
for (; i < len - 3; i += 4) {
float diff0 = a[i] - b[i];
float diff1 = a[i + 1] - b[i + 1];
float diff2 = a[i + 2] - b[i + 2];
float diff3 = a[i + 3] - b[i + 3];
sum += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
}
// 处理剩余元素
for (; i < len; i++) {
float diff = a[i] - b[i];
sum += diff * diff;
}
return sum;
}
}
余弦距离基于余弦相似度,主要关注向量的方向而非大小:
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 余弦距离实现
* 适用于文本相似性、推荐系统等场景
*/
publicclass CosineDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float dotProduct = 0.0f;
float normA = 0.0f;
float normB = 0.0f;
// 一次遍历计算所有必要值
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
// 处理零向量情况
if (normA == 0.0f || normB == 0.0f) {
return1.0f; // 最大距离
}
float similarity = dotProduct / (float) (Math.sqrt(normA) * Math.sqrt(normB));
// 确保相似度在[-1, 1]范围内
similarity = Math.max(-1.0f, Math.min(1.0f, similarity));
return1.0f - similarity;
}
@Override
public String getName() {
return"cosine";
}
/**
* 直接计算余弦相似度
* @param a 向量a
* @param b 向量b
* @return 余弦相似度
*/
public float similarity(float[] a, float[] b) {
return1.0f - distance(a, b);
}
}
对于频繁计算余弦距离的场景,可以预先归一化向量:
/**
* 预归一化余弦距离计算
*/
publicclass NormalizedCosineDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
// 假设输入向量已经归一化,直接计算点积
float dotProduct = 0.0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
}
return1.0f - dotProduct;
}
@Override
public String getName() {
return"normalized_cosine";
}
/**
* 归一化向量
* @param vector 原始向量
* @return 归一化后的向量
*/
publicstaticfloat[] normalize(float[] vector) {
float norm = 0.0f;
for (float v : vector) {
norm += v * v;
}
norm = (float) Math.sqrt(norm);
if (norm == 0.0f) {
return vector.clone();
}
float[] normalized = newfloat[vector.length];
for (int i = 0; i < vector.length; i++) {
normalized[i] = vector[i] / norm;
}
return normalized;
}
}
曼哈顿距离计算向量在各维度上差值的绝对值之和:
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 曼哈顿距离(L1距离)实现
* 对异常值不敏感,适用于高维稀疏数据
*/
publicclass ManhattanDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float sum = 0.0f;
for (int i = 0; i < a.length; i++) {
sum += Math.abs(a[i] - b[i]);
}
return sum;
}
@Override
public String getName() {
return"manhattan";
}
}
切比雪夫距离取各维度差值绝对值的最大值:
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 切比雪夫距离(L∞距离)实现
*/
publicclass ChebyshevDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float maxDiff = 0.0f;
for (int i = 0; i < a.length; i++) {
float diff = Math.abs(a[i] - b[i]);
maxDiff = Math.max(maxDiff, diff);
}
return maxDiff;
}
@Override
public String getName() {
return"chebyshev";
}
}
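闵可夫斯基距离是Lp距离的通用形式:
$$d_p(\mathbf{a},\mathbf{b}) = \left(\sum_{i=1}^{n} |a_i - b_i|^p\right)^{1/p}, \quad p > 0$$
当p=1时为曼哈顿距离,p=2时为欧几里得距离,p→∞时为切比雪夫距离。注意当0<p<1时它不满足三角不等式,因此不是严格意义上的距离度量。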
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 闵可夫斯基距离实现
* 支持任意p值的Lp距离计算
*/
publicclass MinkowskiDistance implements DistanceMetric {
privatefinaldouble p;
public MinkowskiDistance(double p) {
if (p <= 0) {
thrownew IllegalArgumentException("p must be positive");
}
this.p = p;
}
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
if (p == 1.0) {
// 曼哈顿距离优化
return manhattanDistance(a, b);
} elseif (p == 2.0) {
// 欧几里得距离优化
return euclideanDistance(a, b);
} elseif (Double.isInfinite(p)) {
// 切比雪夫距离
return chebyshevDistance(a, b);
}
// 通用闵可夫斯基距离
double sum = 0.0;
for (int i = 0; i < a.length; i++) {
sum += Math.pow(Math.abs(a[i] - b[i]), p);
}
return (float) Math.pow(sum, 1.0 / p);
}
@Override
public String getName() {
return"minkowski_" + p;
}
private float manhattanDistance(float[] a, float[] b) {
float sum = 0.0f;
for (int i = 0; i < a.length; i++) {
sum += Math.abs(a[i] - b[i]);
}
return sum;
}
private float euclideanDistance(float[] a, float[] b) {
float sum = 0.0f;
for (int i = 0; i < a.length; i++) {
float diff = a[i] - b[i];
sum += diff * diff;
}
return (float) Math.sqrt(sum);
}
private float chebyshevDistance(float[] a, float[] b) {
float max = 0.0f;
for (int i = 0; i < a.length; i++) {
max = Math.max(max, Math.abs(a[i] - b[i]));
}
return max;
}
}
汉明距离用于计算两个二进制向量中不同位的数量:
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 汉明距离实现
* 适用于二进制特征、分类数据等离散场景
*/
publicclass HammingDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
int differences = 0;
for (int i = 0; i < a.length; i++) {
if (a[i] != b[i]) {
differences++;
}
}
return differences;
}
@Override
public String getName() {
return"hamming";
}
/**
* 计算标准化汉明距离(0-1之间)
* @param a 向量a
* @param b 向量b
* @return 标准化汉明距离
*/
public float normalizedDistance(float[] a, float[] b) {
return distance(a, b) / a.length;
}
/**
* 针对二进制数据优化的汉明距离计算
* @param a 二进制向量a
* @param b 二进制向量b
* @return 汉明距离
*/
public int binaryHammingDistance(boolean[] a, boolean[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
int differences = 0;
for (int i = 0; i < a.length; i++) {
if (a[i] != b[i]) {
differences++;
}
}
return differences;
}
}
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 雅卡德距离实现
* 适用于集合相似性、二进制特征等
*/
publicclass JaccardDistance implements DistanceMetric {
privatefinalfloat threshold;
public JaccardDistance() {
this(0.0f);
}
public JaccardDistance(float threshold) {
this.threshold = threshold;
}
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
int intersection = 0;
int union = 0;
for (int i = 0; i < a.length; i++) {
boolean aPresent = a[i] > threshold;
boolean bPresent = b[i] > threshold;
if (aPresent && bPresent) {
intersection++;
}
if (aPresent || bPresent) {
union++;
}
}
if (union == 0) {
return0.0f; // 两个空集的距离为0
}
float similarity = (float) intersection / union;
return1.0f - similarity;
}
@Override
public String getName() {
return"jaccard";
}
}
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 布雷-柯蒂斯距离实现
* 归一化的曼哈顿距离,适用于生态学、化学分析等领域
*/
publicclass BrayCurtisDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float numerator = 0.0f;
float denominator = 0.0f;
for (int i = 0; i < a.length; i++) {
numerator += Math.abs(a[i] - b[i]);
denominator += Math.abs(a[i] + b[i]);
}
if (denominator == 0.0f) {
return0.0f; // 两个零向量的距离为0
}
return numerator / denominator;
}
@Override
public String getName() {
return"braycurtis";
}
}
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 堪培拉距离实现
* 曼哈顿距离的加权版本,对接近零的值更敏感
*/
publicclass CanberraDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float sum = 0.0f;
for (int i = 0; i < a.length; i++) {
float numerator = Math.abs(a[i] - b[i]);
float denominator = Math.abs(a[i]) + Math.abs(b[i]);
if (denominator != 0.0f) {
sum += numerator / denominator;
}
// 如果分母为0(两个值都是0),贡献为0
}
return sum;
}
@Override
public String getName() {
return"canberra";
}
}
内积距离用于最大内积搜索(MIPS),计算负点积:
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
/**
* 内积距离实现
* 用于最大内积搜索(Maximum Inner Product Search, MIPS)
*/
publicclass InnerProductDistance implements DistanceMetric {
@Override
public float distance(float[] a, float[] b) {
if (a.length != b.length) {
thrownew IllegalArgumentException("Vector dimensions must match");
}
float innerProduct = 0.0f;
for (int i = 0; i < a.length; i++) {
innerProduct += a[i] * b[i];
}
return -innerProduct; // 负内积作为距离
}
@Override
public String getName() {
return"innerproduct";
}
@Override
public boolean isSimilarity() {
returnfalse; // 虽然内积本身是相似度,但这里返回的是负值
}
/**
* 直接计算内积
* @param a 向量a
* @param b 向量b
* @return 内积值
*/
public float innerProduct(float[] a, float[] b) {
return -distance(a, b);
}
}
package com.jvector.core.metric;
import com.jvector.core.DistanceMetric;
import java.util.HashMap;
import java.util.Map;
/**
* 距离度量工厂类
* 负责创建和管理各种距离度量实例
*/
publicclass DistanceMetricFactory {
privatestaticfinal Map<String, Class<? extends DistanceMetric>> METRIC_REGISTRY = new HashMap<>();
static {
// 注册所有内置距离度量
registerMetric("euclidean", EuclideanDistance.class);
registerMetric("l2", EuclideanDistance.class);
registerMetric("cosine", CosineDistance.class);
registerMetric("manhattan", ManhattanDistance.class);
registerMetric("l1", ManhattanDistance.class);
registerMetric("chebyshev", ChebyshevDistance.class);
registerMetric("linf", ChebyshevDistance.class);
registerMetric("hamming", HammingDistance.class);
registerMetric("jaccard", JaccardDistance.class);
registerMetric("braycurtis", BrayCurtisDistance.class);
registerMetric("canberra", CanberraDistance.class);
registerMetric("innerproduct", InnerProductDistance.class);
registerMetric("ip", InnerProductDistance.class);
}
/**
* 注册距离度量
* @param name 度量名称
* @param clazz 度量类
*/
public static void registerMetric(String name, Class<? extends DistanceMetric> clazz) {
METRIC_REGISTRY.put(name.toLowerCase(), clazz);
}
/**
* 创建距离度量实例
* @param name 度量名称
* @return 距离度量实例
*/
public static DistanceMetric create(String name) {
return create(name, (Object[]) null);
}
/**
* 创建带参数的距离度量实例
* @param name 度量名称
* @param params 构造参数
* @return 距离度量实例
*/
public static DistanceMetric create(String name, Object... params) {
String lowerName = name.toLowerCase();
// 处理特殊情况
if (lowerName.equals("minkowski") && params != null && params.length > 0) {
double p = ((Number) params[0]).doubleValue();
returnnew MinkowskiDistance(p);
}
if (lowerName.equals("jaccard") && params != null && params.length > 0) {
float threshold = ((Number) params[0]).floatValue();
returnnew JaccardDistance(threshold);
}
// 标准创建流程
Class<? extends DistanceMetric> clazz = METRIC_REGISTRY.get(lowerName);
if (clazz == null) {
thrownew IllegalArgumentException("Unknown distance metric: " + name);
}
try {
return clazz.newInstance();
} catch (Exception e) {
thrownew RuntimeException("Failed to create distance metric: " + name, e);
}
}
/**
* 获取所有支持的距离度量名称
* @return 距离度量名称集合
*/
publicstatic java.util.Set<String> getSupportedMetrics() {
return METRIC_REGISTRY.keySet();
}
}
/**
* 距离度量性能测试
*/
publicclass DistanceMetricBenchmark {
@Test
public void benchmarkDistanceMetrics() {
int dimension = 512;
int iterations = 100000;
float[] vector1 = generateRandomVector(dimension);
float[] vector2 = generateRandomVector(dimension);
// 测试各种距离度量的性能
benchmarkMetric("euclidean", new EuclideanDistance(), vector1, vector2, iterations);
benchmarkMetric("cosine", new CosineDistance(), vector1, vector2, iterations);
benchmarkMetric("manhattan", new ManhattanDistance(), vector1, vector2, iterations);
benchmarkMetric("chebyshev", new ChebyshevDistance(), vector1, vector2, iterations);
}
private void benchmarkMetric(String name, DistanceMetric metric,
float[] v1, float[] v2, int iterations) {
// 预热
for (int i = 0; i < 1000; i++) {
metric.distance(v1, v2);
}
long startTime = System.nanoTime();
for (int i = 0; i < iterations; i++) {
metric.distance(v1, v2);
}
long endTime = System.nanoTime();
double avgTime = (endTime - startTime) / (double) iterations;
System.out.printf("%s: %.2f ns/op\n", name, avgTime);
}
privatefloat[] generateRandomVector(int dimension) {
Random random = new Random(42);
float[] vector = newfloat[dimension];
for (int i = 0; i < dimension; i++) {
vector[i] = random.nextFloat();
}
return vector;
}
}
/**
* 距离度量准确性测试
*/
publicclass DistanceMetricAccuracyTest {
@Test
public void testDistanceProperties() {
float[] zero = {0, 0, 0};
float[] a = {1, 2, 3};
float[] b = {4, 5, 6};
DistanceMetric metric = new EuclideanDistance();
// 测试非负性
assertTrue(metric.distance(a, b) >= 0);
// 测试同一性
assertEquals(0.0f, metric.distance(a, a), 1e-6);
// 测试对称性
assertEquals(metric.distance(a, b), metric.distance(b, a), 1e-6);
// 测试三角不等式
float[] c = {7, 8, 9};
float dab = metric.distance(a, b);
float dbc = metric.distance(b, c);
float dac = metric.distance(a, c);
assertTrue(dac <= dab + dbc + 1e-6); // 考虑浮点误差
}
@Test
public void testSpecialCases() {
float[] zero = {0, 0, 0};
float[] ones = {1, 1, 1};
float[] negative = {-1, -1, -1};
// 测试零向量
CosineDistance cosine = new CosineDistance();
assertEquals(1.0f, cosine.distance(zero, ones), 1e-6);
// 测试反向向量
assertEquals(2.0f, cosine.distance(ones, negative), 1e-6);
}
}
本章详细介绍了向量搜索引擎中各种距离度量算法的实现:
在下一章中,我们将深入学习HNSW算法的原理和实现,这是整个向量搜索引擎的核心。
思考题: