首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >第三章:向量数据结构的实现

第三章:向量数据结构的实现

作者头像
javpower
发布2025-07-26 11:11:56
发布2025-07-26 11:11:56
1700
举报

第三章:向量数据结构的实现

3.1 向量类的设计思路

3.1.1 设计目标

向量类(Vector)是整个搜索引擎的基础数据结构,需要满足以下要求:

  1. 高效存储:使用原生float数组存储向量数据
  2. 内存优化:支持缓存范数计算结果,避免重复计算
  3. 类型安全:提供类型安全的访问方法
  4. 序列化支持:支持序列化以便持久化存储
  5. 线程安全:部分操作需要考虑并发访问

3.1.2 核心属性设计

代码语言:javascript
复制
public class Vector implements Serializable {
    private static final long serialVersionUID = 1L;

    private float[] data;           // 向量数据
    private int dimension;          // 向量维度
    private volatile Float norm;    // 缓存的L2范数,使用volatile确保线程安全
}

3.1.3 设计模式运用

  • 值对象模式:Vector作为值对象,具有不可变性
  • 缓存模式:延迟计算并缓存范数值
  • 建造器模式:提供多种构造方式

3.2 Vector类完整实现

3.2.1 构造函数设计

代码语言:javascript
复制
package com.jvector.core;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;

/**
 * 向量类,表示一个多维向量
 */
public class Vector implements Serializable {
    private static final long serialVersionUID = 1L;

    private float[] data;
    private int dimension;
    private volatile Float norm;

    /**
     * 无参构造函数,用于序列化
     */
    public Vector() {
        // 默认构造函数,用于Kryo序列化
    }

    /**
     * 基于float数组构造向量
     * @param data 向量数据,不能为null
     */
    public Vector(float[] data) {
        this.data = Objects.requireNonNull(data, "Vector data cannot be null");
        this.dimension = data.length;
        if (dimension == 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }
    }

    /**
     * 创建指定维度的零向量
     * @param dimension 向量维度,必须大于0
     */
    public Vector(int dimension) {
        if (dimension <= 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }
        this.dimension = dimension;
        this.data = new float[dimension];
    }
}

3.2.2 数据访问方法

代码语言:javascript
复制
/**
 * 获取指定位置的向量值
 * @param index 索引位置
 * @return 向量值
 * @throws IndexOutOfBoundsException 如果索引越界
 */
public float get(int index) {
    if (index < 0 || index >= dimension) {
        throw new IndexOutOfBoundsException("Index " + index + " out of bounds for dimension " + dimension);
    }
    return data[index];
}

/**
 * 设置指定位置的向量值
 * @param index 索引位置
 * @param value 新值
 * @throws IndexOutOfBoundsException 如果索引越界
 */
public void set(int index, float value) {
    if (index < 0 || index >= dimension) {
        throw new IndexOutOfBoundsException("Index " + index + " out of bounds for dimension " + dimension);
    }
    data[index] = value;
    norm = null; // 重置缓存的范数
}

/**
 * 获取向量维度
 * @return 向量维度
 */
public int getDimension() {
    return dimension;
}

/**
 * 获取向量数据的副本(防止外部修改)
 * @return 向量数据副本
 */
public float[] getData() {
    return Arrays.copyOf(data, dimension);
}

3.2.3 数学运算方法

代码语言:javascript
复制
/**
 * 计算向量的L2范数(欧几里得长度)
 * 使用延迟计算和缓存机制提高性能
 * @return L2范数
 */
public float norm() {
    if (norm == null) {
        synchronized (this) {
            if (norm == null) { // 双重检查锁定
                float sum = 0;
                for (float v : data) {
                    sum += v * v;
                }
                norm = (float) Math.sqrt(sum);
            }
        }
    }
    return norm;
}

/**
 * 计算向量的L1范数(曼哈顿距离)
 * @return L1范数
 */
public float norm1() {
    float sum = 0;
    for (float v : data) {
        sum += Math.abs(v);
    }
    return sum;
}

/**
 * 归一化向量(L2归一化)
 * @return 新的归一化向量
 */
public Vector normalize() {
    float n = norm();
    if (n == 0) {
        return new Vector(Arrays.copyOf(data, dimension));
    }
    float[] normalized = new float[dimension];
    for (int i = 0; i < dimension; i++) {
        normalized[i] = data[i] / n;
    }
    return new Vector(normalized);
}

/**
 * 计算与另一个向量的点积
 * @param other 另一个向量
 * @return 点积结果
 * @throws IllegalArgumentException 如果向量维度不匹配
 */
public float dot(Vector other) {
    if (this.dimension != other.dimension) {
        throw new IllegalArgumentException(
            String.format("Vector dimensions must match: %d vs %d", 
                         this.dimension, other.dimension));
    }
    float sum = 0;
    for (int i = 0; i < dimension; i++) {
        sum += this.data[i] * other.data[i];
    }
    return sum;
}

/**
 * 向量加法
 * @param other 另一个向量
 * @return 相加后的新向量
 */
public Vector add(Vector other) {
    if (this.dimension != other.dimension) {
        throw new IllegalArgumentException("Vector dimensions must match");
    }
    float[] result = new float[dimension];
    for (int i = 0; i < dimension; i++) {
        result[i] = this.data[i] + other.data[i];
    }
    return new Vector(result);
}

/**
 * 向量减法
 * @param other 另一个向量
 * @return 相减后的新向量
 */
public Vector subtract(Vector other) {
    if (this.dimension != other.dimension) {
        throw new IllegalArgumentException("Vector dimensions must match");
    }
    float[] result = new float[dimension];
    for (int i = 0; i < dimension; i++) {
        result[i] = this.data[i] - other.data[i];
    }
    return new Vector(result);
}

/**
 * 向量标量乘法
 * @param scalar 标量值
 * @return 相乘后的新向量
 */
public Vector multiply(float scalar) {
    float[] result = new float[dimension];
    for (int i = 0; i < dimension; i++) {
        result[i] = this.data[i] * scalar;
    }
    return new Vector(result);
}

3.3 SearchResult类的实现

3.3.1 搜索结果设计

SearchResult类用于封装搜索操作的结果,包含向量ID、距离值和可选的向量数据。

代码语言:javascript
复制
package com.jvector.core;

import java.io.Serializable;
import java.util.Objects;

/**
 * 搜索结果类,封装搜索操作的结果
 */
public class SearchResult implements Serializable, Comparable<SearchResult> {
    private static final long serialVersionUID = 1L;

    private final long id;           // 向量ID
    private final float distance;    // 距离值
    private final Vector vector;     // 向量数据(可选)

    /**
     * 构造搜索结果
     * @param id 向量ID
     * @param distance 距离值
     */
    public SearchResult(long id, float distance) {
        this(id, distance, null);
    }

    /**
     * 构造包含向量数据的搜索结果
     * @param id 向量ID
     * @param distance 距离值
     * @param vector 向量数据
     */
    public SearchResult(long id, float distance, Vector vector) {
        this.id = id;
        this.distance = distance;
        this.vector = vector;
    }

    // Getter methods
    public long getId() {
        return id;
    }

    public float getDistance() {
        return distance;
    }

    public Vector getVector() {
        return vector;
    }

    /**
     * 检查是否包含向量数据
     * @return 是否包含向量数据
     */
    public boolean hasVector() {
        return vector != null;
    }

    /**
     * 比较方法,按距离升序排列
     */
    @Override
    public int compareTo(SearchResult other) {
        return Float.compare(this.distance, other.distance);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        SearchResult that = (SearchResult) o;
        return id == that.id &&
               Float.compare(that.distance, distance) == 0 &&
               Objects.equals(vector, that.vector);
    }

    @Override
    public int hashCode() {
        return Objects.hash(id, distance, vector);
    }

    @Override
    public String toString() {
        return String.format("SearchResult{id=%d, distance=%.6f, hasVector=%b}", 
                           id, distance, hasVector());
    }
}

3.3.2 搜索结果工具方法

代码语言:javascript
复制
/**
 * 搜索结果工具类
 */
public class SearchResults {

    /**
     * 按距离排序搜索结果
     * @param results 搜索结果列表
     * @param ascending 是否升序排列
     */
    public static void sortByDistance(List<SearchResult> results, boolean ascending) {
        if (ascending) {
            results.sort(Comparator.comparingDouble(SearchResult::getDistance));
        } else {
            results.sort(Comparator.comparingDouble(SearchResult::getDistance).reversed());
        }
    }

    /**
     * 过滤距离小于阈值的结果
     * @param results 搜索结果列表
     * @param threshold 距离阈值
     * @return 过滤后的结果
     */
    public static List<SearchResult> filterByDistance(List<SearchResult> results, float threshold) {
        return results.stream()
                     .filter(r -> r.getDistance() <= threshold)
                     .collect(Collectors.toList());
    }

    /**
     * 提取所有向量ID
     * @param results 搜索结果列表
     * @return ID列表
     */
    public static List<Long> extractIds(List<SearchResult> results) {
        return results.stream()
                     .map(SearchResult::getId)
                     .collect(Collectors.toList());
    }
}

3.4 向量数据的内存优化

3.4.1 内存布局优化

3.4.2 向量池设计

为了减少对象创建和垃圾回收的开销,可以实现向量对象池:

代码语言:javascript
复制
/**
 * 向量对象池,减少对象创建开销
 */
public class VectorPool {
    private final Queue<Vector> pool = new ConcurrentLinkedQueue<>();
    private final int maxPoolSize;

    public VectorPool(int maxPoolSize) {
        this.maxPoolSize = maxPoolSize;
    }

    /**
     * 获取向量对象
     * @param dimension 向量维度
     * @return 向量对象
     */
    public Vector acquire(int dimension) {
        Vector vector = pool.poll();
        if (vector == null || vector.getDimension() != dimension) {
            return new Vector(dimension);
        }
        return vector;
    }

    /**
     * 归还向量对象
     * @param vector 向量对象
     */
    public void release(Vector vector) {
        if (pool.size() < maxPoolSize) {
            // 清零向量数据
            float[] data = vector.getData();
            Arrays.fill(data, 0.0f);
            pool.offer(vector);
        }
    }
}

3.4.3 压缩存储

对于高维稀疏向量,可以考虑使用压缩存储格式:

代码语言:javascript
复制
/**
 * 稀疏向量实现
 */
public class SparseVector extends Vector {
    private final Map<Integer, Float> sparseData;
    private final int dimension;

    public SparseVector(int dimension) {
        this.dimension = dimension;
        this.sparseData = new HashMap<>();
    }

    @Override
    public float get(int index) {
        return sparseData.getOrDefault(index, 0.0f);
    }

    @Override
    public void set(int index, float value) {
        if (value == 0.0f) {
            sparseData.remove(index);
        } else {
            sparseData.put(index, value);
        }
    }

    /**
     * 获取非零元素数量
     * @return 非零元素数量
     */
    public int getNonZeroCount() {
        return sparseData.size();
    }

    /**
     * 计算稀疏度
     * @return 稀疏度(0-1之间)
     */
    public float getSparsity() {
        return 1.0f - (float) getNonZeroCount() / dimension;
    }
}

3.5 向量验证和安全性

3.5.1 输入验证

代码语言:javascript
复制
/**
 * 向量验证工具类
 */
public class VectorValidator {

    /**
     * 验证向量数组是否有效
     * @param data 向量数据
     * @throws IllegalArgumentException 如果数据无效
     */
    public static void validateVectorData(float[] data) {
        Objects.requireNonNull(data, "Vector data cannot be null");

        if (data.length == 0) {
            throw new IllegalArgumentException("Vector dimension must be greater than 0");
        }

        // 检查是否包含NaN或无穷大
        for (int i = 0; i < data.length; i++) {
            if (!Float.isFinite(data[i])) {
                throw new IllegalArgumentException(
                    String.format("Vector contains invalid value at index %d: %f", i, data[i]));
            }
        }
    }

    /**
     * 验证向量维度是否匹配
     * @param v1 向量1
     * @param v2 向量2
     * @throws IllegalArgumentException 如果维度不匹配
     */
    public static void validateDimensionMatch(Vector v1, Vector v2) {
        if (v1.getDimension() != v2.getDimension()) {
            throw new IllegalArgumentException(
                String.format("Vector dimensions must match: %d vs %d", 
                             v1.getDimension(), v2.getDimension()));
        }
    }

    /**
     * 验证向量是否已归一化
     * @param vector 向量
     * @param tolerance 容忍误差
     * @return 是否已归一化
     */
    public static boolean isNormalized(Vector vector, float tolerance) {
        float norm = vector.norm();
        return Math.abs(norm - 1.0f) <= tolerance;
    }
}

3.5.2 不可变向量设计

为了确保线程安全,可以设计不可变的向量类:

代码语言:javascript
复制
/**
 * 不可变向量类
 */
public final class ImmutableVector extends Vector {

    public ImmutableVector(float[] data) {
        super(Arrays.copyOf(data, data.length)); // 深拷贝确保不可变性
    }

    @Override
    public void set(int index, float value) {
        throw new UnsupportedOperationException("ImmutableVector cannot be modified");
    }

    @Override
    public float[] getData() {
        // 返回副本而不是原始数组引用
        return super.getData();
    }

    /**
     * 创建修改后的新向量
     * @param index 要修改的索引
     * @param value 新值
     * @return 新的不可变向量
     */
    public ImmutableVector withValue(int index, float value) {
        float[] newData = getData();
        newData[index] = value;
        return new ImmutableVector(newData);
    }
}

3.6 性能测试和基准测试

3.6.1 向量操作基准测试

代码语言:javascript
复制
/**
 * 向量性能测试
 */
public class VectorBenchmark {

    @Test
    public void benchmarkVectorOperations() {
        int dimension = 1024;
        int iterations = 100000;

        // 测试向量创建性能
        long startTime = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            Vector v = new Vector(generateRandomVector(dimension));
        }
        long creationTime = System.nanoTime() - startTime;

        // 测试点积计算性能
        Vector v1 = new Vector(generateRandomVector(dimension));
        Vector v2 = new Vector(generateRandomVector(dimension));

        startTime = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            float result = v1.dot(v2);
        }
        long dotProductTime = System.nanoTime() - startTime;

        // 测试范数计算性能
        startTime = System.nanoTime();
        for (int i = 0; i < iterations; i++) {
            float norm = v1.norm();
        }
        long normTime = System.nanoTime() - startTime;

        System.out.println("Vector creation: " + creationTime / 1000000 + " ms");
        System.out.println("Dot product: " + dotProductTime / 1000000 + " ms");
        System.out.println("Norm calculation: " + normTime / 1000000 + " ms");
    }

    @Test
    public void benchmarkMemoryUsage() {
        int vectorCount = 10000;
        int dimension = 512;

        // 使用MemoryMXBean监控内存使用
        MemoryMXBean memoryBean = ManagementFactory.getMemoryMXBean();

        long beforeMemory = memoryBean.getHeapMemoryUsage().getUsed();

        List<Vector> vectors = new ArrayList<>();
        for (int i = 0; i < vectorCount; i++) {
            vectors.add(new Vector(generateRandomVector(dimension)));
        }

        long afterMemory = memoryBean.getHeapMemoryUsage().getUsed();
        long memoryUsed = afterMemory - beforeMemory;

        System.out.println("Memory used for " + vectorCount + " vectors: " + 
                          memoryUsed / 1024 / 1024 + " MB");
        System.out.println("Average memory per vector: " + 
                          memoryUsed / vectorCount + " bytes");
    }

    private float[] generateRandomVector(int dimension) {
        Random random = new Random();
        float[] vector = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            vector[i] = random.nextFloat();
        }
        return vector;
    }
}

3.7 向量序列化和反序列化

3.7.1 自定义序列化

代码语言:javascript
复制
/**
 * 自定义向量序列化器
 */
public class VectorSerializer {

    /**
     * 序列化向量到字节数组
     * @param vector 向量对象
     * @return 字节数组
     */
    public static byte[] serialize(Vector vector) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(baos);

        // 写入维度
        dos.writeInt(vector.getDimension());

        // 写入向量数据
        float[] data = vector.getData();
        for (float value : data) {
            dos.writeFloat(value);
        }

        dos.close();
        return baos.toByteArray();
    }

    /**
     * 从字节数组反序列化向量
     * @param bytes 字节数组
     * @return 向量对象
     */
    public static Vector deserialize(byte[] bytes) throws IOException {
        ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
        DataInputStream dis = new DataInputStream(bais);

        // 读取维度
        int dimension = dis.readInt();

        // 读取向量数据
        float[] data = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            data[i] = dis.readFloat();
        }

        dis.close();
        return new Vector(data);
    }
}

3.7.2 Kryo序列化支持

代码语言:javascript
复制
/**
 * Kryo序列化器注册
 */
public class VectorKryoSerializer extends Serializer<Vector> {

    @Override
    public void write(Kryo kryo, Output output, Vector vector) {
        output.writeInt(vector.getDimension());
        float[] data = vector.getData();
        for (float value : data) {
            output.writeFloat(value);
        }
    }

    @Override
    public Vector read(Kryo kryo, Input input, Class<Vector> type) {
        int dimension = input.readInt();
        float[] data = new float[dimension];
        for (int i = 0; i < dimension; i++) {
            data[i] = input.readFloat();
        }
        return new Vector(data);
    }
}

小结

本章详细介绍了向量搜索引擎中核心数据结构的设计和实现:

  1. Vector类
    • 支持多种构造方式
    • 提供完整的数学运算方法
    • 使用缓存机制优化性能
    • 保证线程安全
  2. SearchResult类
    • 封装搜索结果
    • 支持排序和比较
    • 提供工具方法
  3. 性能优化
    • 对象池减少创建开销
    • 稀疏向量支持
    • 内存布局优化
  4. 安全性
    • 输入验证
    • 不可变设计
    • 异常处理

在下一章中,我们将实现各种距离度量算法,这是向量搜索的核心计算组件。

思考题:

  1. 为什么要使用volatile关键字修饰norm字段?
  2. 稀疏向量在什么场景下比密集向量更有优势?
  3. 如何设计一个支持动态维度的向量类?
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-07-23,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Coder建设 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 第三章:向量数据结构的实现
    • 3.1 向量类的设计思路
      • 3.1.1 设计目标
      • 3.1.2 核心属性设计
      • 3.1.3 设计模式运用
    • 3.2 Vector类完整实现
      • 3.2.1 构造函数设计
      • 3.2.2 数据访问方法
      • 3.2.3 数学运算方法
    • 3.3 SearchResult类的实现
      • 3.3.1 搜索结果设计
      • 3.3.2 搜索结果工具方法
    • 3.4 向量数据的内存优化
      • 3.4.1 内存布局优化
      • 3.4.2 向量池设计
      • 3.4.3 压缩存储
    • 3.5 向量验证和安全性
      • 3.5.1 输入验证
      • 3.5.2 不可变向量设计
    • 3.6 性能测试和基准测试
      • 3.6.1 向量操作基准测试
    • 3.7 向量序列化和反序列化
      • 3.7.1 自定义序列化
      • 3.7.2 Kryo序列化支持
    • 小结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档