向量类(Vector)是整个搜索引擎的基础数据结构,需要满足以下要求:支持序列化、以float数组紧凑存储分量,并缓存L2范数以避免重复计算。其核心字段如下:
/**
 * Field layout of the vector class (constructors and operations are shown below).
 */
public class Vector implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Vector components. */
    private float[] data;

    /** Vector dimension, i.e. the number of components. */
    private int dimension;

    /** Cached L2 norm; volatile so a cached value written by one thread is visible to others. */
    private volatile Float norm;
}
package com.jvector.core;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
/**
 * A dense, fixed-dimension vector of {@code float} components.
 *
 * <p>The L2 norm is computed lazily and cached; {@link #set(int, float)} invalidates the cache.
 * Arithmetic operations ({@link #add}, {@link #subtract}, {@link #multiply}, {@link #normalize})
 * return new vectors and never modify their operands.
 */
public class Vector implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Component storage; the array passed to {@link #Vector(float[])} is used as-is (not copied). */
    private float[] data;
    /** Number of components; equals {@code data.length} for any constructed vector. */
    private int dimension;
    /** Cached L2 norm, or {@code null} when not yet computed or invalidated by {@link #set}. */
    private volatile Float norm;

    /**
     * No-arg constructor for serialization frameworks (e.g. Kryo).
     *
     * <p>Leaves the vector without storage; it must be populated before any other method is used.
     */
    public Vector() {
        // Intentionally empty: used by Kryo-style serialization.
    }

    /**
     * Creates a vector backed by the given array.
     *
     * @param data component values; must be non-null and non-empty. The array is NOT copied, so
     *             later writes to it are visible through this vector.
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public Vector(float[] data) {
        this.data = Objects.requireNonNull(data, "Vector data cannot be null");
        this.dimension = data.length;
        if (dimension == 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }
    }

    /**
     * Creates a zero vector of the given dimension.
     *
     * @param dimension number of components; must be greater than 0
     * @throws IllegalArgumentException if {@code dimension <= 0}
     */
    public Vector(int dimension) {
        if (dimension <= 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }
        this.dimension = dimension;
        this.data = new float[dimension];
    }

    /**
     * Returns the component at {@code index}.
     *
     * @param index zero-based component index
     * @return the component value
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, dimension)}
     */
    public float get(int index) {
        checkIndex(index);
        return data[index];
    }

    /**
     * Sets the component at {@code index} and invalidates the cached L2 norm.
     *
     * @param index zero-based component index
     * @param value new component value
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, dimension)}
     */
    public void set(int index, float value) {
        checkIndex(index);
        data[index] = value;
        norm = null; // the cached norm no longer matches the data
    }

    /**
     * Returns the number of components.
     *
     * @return the vector dimension
     */
    public int getDimension() {
        return dimension;
    }

    /**
     * Returns a copy of the components, so callers cannot modify this vector through it.
     *
     * @return a fresh array of length {@link #getDimension()}
     */
    public float[] getData() {
        return Arrays.copyOf(data, dimension);
    }

    /**
     * Returns the L2 (Euclidean) norm, computing it on first use and caching the result.
     *
     * @return the L2 norm
     */
    public float norm() {
        // Read the volatile field exactly once: a concurrent set() may reset it to null between
        // a null check and a second read, which would make `return norm` unbox null.
        Float cached = norm;
        if (cached == null) {
            synchronized (this) {
                cached = norm; // double-checked locking: another thread may have filled it
                if (cached == null) {
                    float sum = 0;
                    for (float v : data) {
                        sum += v * v;
                    }
                    cached = (float) Math.sqrt(sum);
                    norm = cached;
                }
            }
        }
        return cached;
    }

    /**
     * Returns the L1 (Manhattan) norm, i.e. the sum of absolute component values.
     *
     * @return the L1 norm
     */
    public float norm1() {
        float sum = 0;
        for (float v : data) {
            sum += Math.abs(v);
        }
        return sum;
    }

    /**
     * Returns an L2-normalized copy of this vector.
     *
     * @return a new unit-length vector, or a plain copy when this vector's norm is 0
     */
    public Vector normalize() {
        float n = norm();
        if (n == 0) {
            return new Vector(Arrays.copyOf(data, dimension));
        }
        float[] normalized = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            normalized[i] = data[i] / n;
        }
        return new Vector(normalized);
    }

    /**
     * Computes the dot product with another vector.
     *
     * @param other the other vector
     * @return the dot product
     * @throws IllegalArgumentException if the dimensions differ
     */
    public float dot(Vector other) {
        checkSameDimension(other);
        float sum = 0;
        for (int i = 0; i < dimension; i++) {
            sum += this.data[i] * other.data[i];
        }
        return sum;
    }

    /**
     * Returns the component-wise sum of this vector and {@code other}.
     *
     * @param other the vector to add
     * @return a new vector holding the sum
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Vector add(Vector other) {
        checkSameDimension(other);
        float[] result = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            result[i] = this.data[i] + other.data[i];
        }
        return new Vector(result);
    }

    /**
     * Returns the component-wise difference {@code this - other}.
     *
     * @param other the vector to subtract
     * @return a new vector holding the difference
     * @throws IllegalArgumentException if the dimensions differ
     */
    public Vector subtract(Vector other) {
        checkSameDimension(other);
        float[] result = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            result[i] = this.data[i] - other.data[i];
        }
        return new Vector(result);
    }

    /**
     * Returns this vector scaled by {@code scalar}.
     *
     * @param scalar the scale factor
     * @return a new scaled vector
     */
    public Vector multiply(float scalar) {
        float[] result = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            result[i] = this.data[i] * scalar;
        }
        return new Vector(result);
    }

    /** Throws IndexOutOfBoundsException unless {@code 0 <= index < dimension}. */
    private void checkIndex(int index) {
        if (index < 0 || index >= dimension) {
            throw new IndexOutOfBoundsException(
                    "Index " + index + " out of bounds for dimension " + dimension);
        }
    }

    /** Throws IllegalArgumentException, naming both dimensions, when they differ. */
    private void checkSameDimension(Vector other) {
        if (this.dimension != other.dimension) {
            throw new IllegalArgumentException(
                    String.format("Vector dimensions must match: %d vs %d",
                            this.dimension, other.dimension));
        }
    }
}
SearchResult类用于封装搜索操作的结果,包含向量ID、距离值和可选的向量数据。
package com.jvector.core;
import java.io.Serializable;
import java.util.Objects;
/**
 * Result of a search: a vector id, its distance to the query, and optionally the vector itself.
 *
 * <p>Natural ordering is by ascending distance, with ties broken by ascending id, so results
 * with equal distances but different ids never compare as equal (and are not collapsed by
 * sorted collections such as {@code TreeSet}).
 */
public class SearchResult implements Serializable, Comparable<SearchResult> {
    private static final long serialVersionUID = 1L;

    /** Id of the matched vector. */
    private final long id;
    /** Distance between the matched vector and the query. */
    private final float distance;
    /** The matched vector, or {@code null} when it was not requested. */
    private final Vector vector;

    /**
     * Creates a result without vector data.
     *
     * @param id       vector id
     * @param distance distance to the query
     */
    public SearchResult(long id, float distance) {
        this(id, distance, null);
    }

    /**
     * Creates a result that may carry the vector data.
     *
     * @param id       vector id
     * @param distance distance to the query
     * @param vector   the matched vector, or {@code null}
     */
    public SearchResult(long id, float distance, Vector vector) {
        this.id = id;
        this.distance = distance;
        this.vector = vector;
    }

    /** @return the vector id */
    public long getId() {
        return id;
    }

    /** @return the distance to the query */
    public float getDistance() {
        return distance;
    }

    /** @return the matched vector, or {@code null} if none was attached */
    public Vector getVector() {
        return vector;
    }

    /**
     * Reports whether vector data is attached.
     *
     * @return {@code true} if {@link #getVector()} is non-null
     */
    public boolean hasVector() {
        return vector != null;
    }

    /**
     * Orders by ascending distance, then by ascending id to make ties deterministic.
     */
    @Override
    public int compareTo(SearchResult other) {
        int byDistance = Float.compare(this.distance, other.distance);
        return byDistance != 0 ? byDistance : Long.compare(this.id, other.id);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        SearchResult that = (SearchResult) o;
        return id == that.id &&
                Float.compare(that.distance, distance) == 0 &&
                Objects.equals(vector, that.vector);
    }

    @Override
    public int hashCode() {
        return Objects.hash(id, distance, vector);
    }

    @Override
    public String toString() {
        return String.format("SearchResult{id=%d, distance=%.6f, hasVector=%b}",
                id, distance, hasVector());
    }
}
/**
 * Static helpers for working with lists of {@link SearchResult}.
 */
public class SearchResults {
    /**
     * Sorts the list in place by distance.
     *
     * @param results   the results to sort; must be non-null and modifiable
     * @param ascending {@code true} for nearest-first, {@code false} for farthest-first
     * @throws NullPointerException if {@code results} is null
     */
    public static void sortByDistance(List<SearchResult> results, boolean ascending) {
        Objects.requireNonNull(results, "results");
        Comparator<SearchResult> byDistance = Comparator.comparingDouble(SearchResult::getDistance);
        results.sort(ascending ? byDistance : byDistance.reversed());
    }

    /**
     * Keeps only results whose distance is less than or equal to {@code threshold}.
     *
     * @param results   the results to filter; not modified
     * @param threshold inclusive upper bound on distance
     * @return a new list with the matching results, in their original order
     * @throws NullPointerException if {@code results} is null
     */
    public static List<SearchResult> filterByDistance(List<SearchResult> results, float threshold) {
        Objects.requireNonNull(results, "results");
        return results.stream()
                .filter(r -> r.getDistance() <= threshold)
                .collect(Collectors.toList());
    }

    /**
     * Extracts the vector ids, preserving order.
     *
     * @param results the results to read
     * @return a new list of ids
     * @throws NullPointerException if {@code results} is null
     */
    public static List<Long> extractIds(List<SearchResult> results) {
        Objects.requireNonNull(results, "results");
        return results.stream()
                .map(SearchResult::getId)
                .collect(Collectors.toList());
    }
}

为了减少对象创建和垃圾回收的开销,可以实现向量对象池:
/**
 * A bounded pool of reusable dense {@link Vector} instances, to reduce allocation and GC load.
 *
 * <p>Backed by a {@link ConcurrentLinkedQueue}. The size bound is best-effort: the size check
 * and the offer in {@link #release} are not performed atomically.
 */
public class VectorPool {
    private final Queue<Vector> pool = new ConcurrentLinkedQueue<>();
    private final int maxPoolSize;

    /**
     * @param maxPoolSize maximum number of vectors kept for reuse; values {@code <= 0} disable pooling
     */
    public VectorPool(int maxPoolSize) {
        this.maxPoolSize = maxPoolSize;
    }

    /**
     * Returns a zeroed vector of the requested dimension, reusing a pooled one when the head of
     * the pool matches.
     *
     * @param dimension required dimension; must be greater than 0 when a new vector is allocated
     * @return a vector whose components are all zero
     * @throws IllegalArgumentException if a new vector must be allocated and {@code dimension <= 0}
     */
    public Vector acquire(int dimension) {
        Vector vector = pool.poll();
        if (vector == null) {
            return new Vector(dimension);
        }
        if (vector.getDimension() != dimension) {
            // Keep the mismatched vector for a later request of its own dimension instead of
            // silently dropping it from the pool.
            pool.offer(vector);
            return new Vector(dimension);
        }
        return vector;
    }

    /**
     * Returns a vector to the pool after clearing it.
     *
     * <p>Only plain {@link Vector} instances are pooled; subclasses (for example an immutable
     * vector whose {@code set} throws) are ignored because they cannot be reset in place.
     *
     * @param vector the vector to recycle; must not be used by the caller afterwards
     * @throws NullPointerException if {@code vector} is null
     */
    public void release(Vector vector) {
        if (vector.getClass() != Vector.class || pool.size() >= maxPoolSize) {
            return;
        }
        // Clear the pooled instance itself: getData() returns a copy, so zeroing that copy would
        // leave stale values behind. set() also invalidates the cached norm.
        for (int i = 0, n = vector.getDimension(); i < n; i++) {
            vector.set(i, 0.0f);
        }
        pool.offer(vector);
    }
}
对于高维稀疏向量,可以考虑使用压缩存储格式:
/**
 * Sparse vector: stores only the non-zero components, in a map keyed by index.
 *
 * <p>The dense storage inherited from {@link Vector} is never allocated, so every inherited
 * operation that would read it is overridden here to work from the sparse map instead.
 * Not thread-safe: backed by a {@link HashMap}.
 *
 * <p>NOTE(review): {@link Vector} operations called on a dense receiver with a
 * {@code SparseVector} argument (e.g. {@code dense.dot(sparse)}) read the argument's superclass
 * fields and will fail; convert first with {@code new Vector(sparse.getData())}.
 */
public class SparseVector extends Vector {
    private final Map<Integer, Float> sparseData;
    private final int dimension;

    /**
     * Creates an all-zero sparse vector.
     *
     * @param dimension number of components; must be greater than 0
     * @throws IllegalArgumentException if {@code dimension <= 0}
     */
    public SparseVector(int dimension) {
        super(); // dense superclass storage is intentionally left unallocated
        if (dimension <= 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }
        this.dimension = dimension;
        this.sparseData = new HashMap<>();
    }

    /** @return the logical dimension of this sparse vector */
    @Override
    public int getDimension() {
        return dimension;
    }

    /**
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, dimension)}
     */
    @Override
    public float get(int index) {
        checkIndex(index);
        return sparseData.getOrDefault(index, 0.0f);
    }

    /**
     * Stores {@code value} at {@code index}; storing 0 removes the entry to keep the map sparse.
     *
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, dimension)}
     */
    @Override
    public void set(int index, float value) {
        checkIndex(index);
        if (value == 0.0f) {
            sparseData.remove(index);
        } else {
            sparseData.put(index, value);
        }
    }

    /** @return a dense copy of the components */
    @Override
    public float[] getData() {
        float[] dense = new float[dimension];
        for (Map.Entry<Integer, Float> entry : sparseData.entrySet()) {
            dense[entry.getKey()] = entry.getValue();
        }
        return dense;
    }

    /** @return the L2 norm, computed from the non-zero entries (not cached) */
    @Override
    public float norm() {
        float sum = 0;
        for (float v : sparseData.values()) {
            sum += v * v;
        }
        return (float) Math.sqrt(sum);
    }

    /** @return the L1 norm, computed from the non-zero entries */
    @Override
    public float norm1() {
        float sum = 0;
        for (float v : sparseData.values()) {
            sum += Math.abs(v);
        }
        return sum;
    }

    /**
     * Dot product visiting only this vector's non-zero entries.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    @Override
    public float dot(Vector other) {
        checkSameDimension(other);
        float sum = 0;
        for (Map.Entry<Integer, Float> entry : sparseData.entrySet()) {
            sum += entry.getValue() * other.get(entry.getKey());
        }
        return sum;
    }

    /** @return a dense, L2-normalized copy */
    @Override
    public Vector normalize() {
        return new Vector(getData()).normalize();
    }

    /**
     * @return a dense vector holding {@code this + other}
     * @throws IllegalArgumentException if the dimensions differ
     */
    @Override
    public Vector add(Vector other) {
        checkSameDimension(other);
        return new Vector(getData()).add(new Vector(other.getData()));
    }

    /**
     * @return a dense vector holding {@code this - other}
     * @throws IllegalArgumentException if the dimensions differ
     */
    @Override
    public Vector subtract(Vector other) {
        checkSameDimension(other);
        return new Vector(getData()).subtract(new Vector(other.getData()));
    }

    /** @return a dense vector holding {@code this * scalar} */
    @Override
    public Vector multiply(float scalar) {
        return new Vector(getData()).multiply(scalar);
    }

    /**
     * Returns the number of stored (non-zero) components.
     *
     * @return the non-zero count
     */
    public int getNonZeroCount() {
        return sparseData.size();
    }

    /**
     * Returns the fraction of components that are zero.
     *
     * @return sparsity in {@code [0, 1]}
     */
    public float getSparsity() {
        return 1.0f - (float) getNonZeroCount() / dimension;
    }

    /** Throws IndexOutOfBoundsException unless {@code 0 <= index < dimension}. */
    private void checkIndex(int index) {
        if (index < 0 || index >= dimension) {
            throw new IndexOutOfBoundsException(
                    "Index " + index + " out of bounds for dimension " + dimension);
        }
    }

    /** Throws IllegalArgumentException, naming both dimensions, when they differ. */
    private void checkSameDimension(Vector other) {
        if (other.getDimension() != dimension) {
            throw new IllegalArgumentException(
                    String.format("Vector dimensions must match: %d vs %d",
                            dimension, other.getDimension()));
        }
    }
}
/**
 * Static checks for raw vector data and for pairs of vectors.
 */
public class VectorValidator {
    /**
     * Rejects null, empty, or non-finite vector data.
     *
     * @param data candidate component values
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code data} is empty or holds NaN/infinite values
     */
    public static void validateVectorData(float[] data) {
        Objects.requireNonNull(data, "Vector data cannot be null");
        if (data.length == 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }
        for (int idx = 0; idx < data.length; idx++) {
            float value = data[idx];
            boolean invalid = Float.isNaN(value) || Float.isInfinite(value);
            if (invalid) {
                throw new IllegalArgumentException(
                        String.format("Vector contains invalid value at index %d: %f", idx, value));
            }
        }
    }

    /**
     * Rejects a pair of vectors whose dimensions differ.
     *
     * @param v1 first vector
     * @param v2 second vector
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static void validateDimensionMatch(Vector v1, Vector v2) {
        int first = v1.getDimension();
        int second = v2.getDimension();
        if (first == second) {
            return;
        }
        throw new IllegalArgumentException(
                String.format("Vector dimensions must match: %d vs %d", first, second));
    }

    /**
     * Reports whether the L2 norm of {@code vector} is within {@code tolerance} of 1.
     *
     * @param vector    the vector to inspect
     * @param tolerance allowed absolute deviation from 1
     * @return {@code true} if {@code |norm - 1| <= tolerance}
     */
    public static boolean isNormalized(Vector vector, float tolerance) {
        float deviation = Math.abs(vector.norm() - 1.0f);
        return deviation <= tolerance;
    }
}
为了确保线程安全,可以设计不可变的向量类:
/**
 * Immutable vector: takes a private copy of its input and rejects {@link #set(int, float)}.
 *
 * <p>NOTE(review): the superclass fields are not {@code final}, so instances shared across
 * threads should still be published safely (e.g. via a final or volatile field).
 */
public final class ImmutableVector extends Vector {
    /**
     * Creates an immutable vector from a copy of {@code data}.
     *
     * @param data component values; must be non-null and non-empty
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code data} is empty
     */
    public ImmutableVector(float[] data) {
        // Copy so later writes to the caller's array cannot change this vector. The null check
        // runs first so a null input gets the same message as Vector's constructor.
        super(Objects.requireNonNull(data, "Vector data cannot be null").clone());
    }

    /**
     * Always fails: this vector cannot be modified.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void set(int index, float value) {
        throw new UnsupportedOperationException("ImmutableVector cannot be modified");
    }

    /**
     * Returns a copy of the components (never the internal array).
     *
     * @return a fresh array
     */
    @Override
    public float[] getData() {
        return super.getData();
    }

    /**
     * Returns a new immutable vector equal to this one except at {@code index}.
     *
     * @param index component to replace
     * @param value new component value
     * @return the modified copy; this vector is unchanged
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, dimension)}
     */
    public ImmutableVector withValue(int index, float value) {
        int dimension = getDimension();
        if (index < 0 || index >= dimension) {
            throw new IndexOutOfBoundsException(
                    "Index " + index + " out of bounds for dimension " + dimension);
        }
        float[] newData = getData();
        newData[index] = value;
        return new ImmutableVector(newData);
    }
}
/**
 * Micro-benchmarks for vector creation, dot product, norm and memory footprint.
 *
 * <p>Results are indicative only (no JIT warm-up phase); use a harness such as JMH for
 * rigorous measurements.
 */
public class VectorBenchmark {
    /** Shared seeded generator: reproducible input, and no Random allocated per vector. */
    private final Random random = new Random(42);

    @Test
    public void benchmarkVectorOperations() {
        int dimension = 1024;
        int iterations = 100000;
        // Accumulates every result so the JIT cannot eliminate the measured work as dead code.
        double sink = 0;

        // Creation: generate the input once so only array copy + construction is timed.
        float[] template = generateRandomVector(dimension);
        long startTime = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            Vector v = new Vector(template.clone());
            sink += v.getDimension();
        }
        long creationTime = System.nanoTime() - startTime;

        // Dot product.
        Vector v1 = new Vector(generateRandomVector(dimension));
        Vector v2 = new Vector(generateRandomVector(dimension));
        startTime = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            sink += v1.dot(v2);
        }
        long dotProductTime = System.nanoTime() - startTime;

        // Norm: set() invalidates the cached norm, so every iteration recomputes it instead of
        // returning the value cached by the first call.
        startTime = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            v1.set(0, v1.get(0));
            sink += v1.norm();
        }
        long normTime = System.nanoTime() - startTime;

        System.out.println("Vector creation: " + creationTime / 1000000 + " ms");
        System.out.println("Dot product: " + dotProductTime / 1000000 + " ms");
        System.out.println("Norm calculation: " + normTime / 1000000 + " ms");
        System.out.println("(checksum: " + sink + ")");
    }

    @Test
    public void benchmarkMemoryUsage() {
        int vectorCount = 10000;
        int dimension = 512;
        MemoryMXBean memoryBean = ManagementFactory.getMemoryMXBean();
        System.gc(); // best-effort: reduce unrelated garbage counted in the baseline
        long beforeMemory = memoryBean.getHeapMemoryUsage().getUsed();
        List<Vector> vectors = new ArrayList<>(vectorCount);
        for (int i = 0; i < vectorCount; i++) {
            vectors.add(new Vector(generateRandomVector(dimension)));
        }
        long afterMemory = memoryBean.getHeapMemoryUsage().getUsed();
        long memoryUsed = afterMemory - beforeMemory;
        // vectors.size() keeps the list reachable until after the measurement.
        System.out.println("Memory used for " + vectors.size() + " vectors: " +
                memoryUsed / 1024 / 1024 + " MB");
        System.out.println("Average memory per vector: " +
                memoryUsed / vectorCount + " bytes");
    }

    /** Returns {@code dimension} uniform random floats in {@code [0, 1)}. */
    private float[] generateRandomVector(int dimension) {
        float[] vector = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            vector[i] = random.nextFloat();
        }
        return vector;
    }
}
/**
 * Custom binary serializer for {@link Vector}.
 *
 * <p>Format: a big-endian {@code int} dimension followed by that many big-endian {@code float}
 * components, as written by {@link DataOutputStream}.
 */
public class VectorSerializer {
    /** Size of the leading dimension header, in bytes. */
    private static final int HEADER_BYTES = Integer.BYTES;

    /**
     * Serializes a vector to bytes.
     *
     * @param vector the vector to encode
     * @return the encoded bytes
     * @throws IOException if writing fails
     */
    public static byte[] serialize(Vector vector) throws IOException {
        float[] data = vector.getData();
        ByteArrayOutputStream baos =
                new ByteArrayOutputStream(HEADER_BYTES + data.length * Float.BYTES);
        try (DataOutputStream dos = new DataOutputStream(baos)) {
            // Header derived from the payload, so the two can never disagree.
            dos.writeInt(data.length);
            for (float value : data) {
                dos.writeFloat(value);
            }
        }
        return baos.toByteArray();
    }

    /**
     * Deserializes a vector written by {@link #serialize(Vector)}.
     *
     * @param bytes encoded bytes (possibly untrusted)
     * @return the decoded vector
     * @throws IOException if the data is truncated or declares a non-positive or impossible dimension
     */
    public static Vector deserialize(byte[] bytes) throws IOException {
        try (DataInputStream dis = new DataInputStream(new ByteArrayInputStream(bytes))) {
            int dimension = dis.readInt();
            // Validate the header before allocating: a negative value would throw
            // NegativeArraySizeException and a huge one could exhaust the heap.
            long required = HEADER_BYTES + (long) dimension * Float.BYTES;
            if (dimension <= 0 || required > bytes.length) {
                throw new IOException(String.format(
                        "Invalid vector payload: dimension=%d, payload length=%d bytes",
                        dimension, bytes.length));
            }
            float[] data = new float[dimension];
            for (int i = 0; i < dimension; i++) {
                data[i] = dis.readFloat();
            }
            return new Vector(data);
        }
    }
}
/**
 * Kryo serializer for {@link Vector}: writes an {@code int} dimension followed by the
 * {@code float} components.
 */
public class VectorKryoSerializer extends Serializer<Vector> {
    @Override
    public void write(Kryo kryo, Output output, Vector vector) {
        float[] data = vector.getData();
        // Header derived from the payload, so the two can never disagree.
        output.writeInt(data.length);
        for (float value : data) {
            output.writeFloat(value);
        }
    }

    /**
     * @throws IllegalArgumentException if the input declares a non-positive dimension
     */
    @Override
    public Vector read(Kryo kryo, Input input, Class<Vector> type) {
        int dimension = input.readInt();
        // Reject corrupt headers before allocating (a negative size would otherwise surface as
        // NegativeArraySizeException).
        if (dimension <= 0) {
            throw new IllegalArgumentException("Invalid vector dimension in Kryo input: " + dimension);
        }
        float[] data = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            data[i] = input.readFloat();
        }
        return new Vector(data);
    }
}
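本章详细介绍了向量搜索引擎中核心数据结构的设计和实现,包括向量类Vector、搜索结果类SearchResult及其工具类SearchResults、向量对象池VectorPool、稀疏向量SparseVector、向量验证工具VectorValidator、不可变向量ImmutableVector、性能测试,以及自定义和Kryo两种序列化方案。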
在下一章中,我们将实现各种距离度量算法,这是向量搜索的核心计算组件。
思考题: