
作者: HOS(安全风信子)
日期: 2026-05-24
主要来源平台: GitHub
摘要: 记忆是 AI IDE 持续发挥作用的关键基础设施。一个完整的 Memory System 需要支持三种截然不同但协同运作的记忆层次:短期记忆承载当前会话的工作上下文与即时状态,长期记忆通过向量嵌入实现跨会话的语义检索,永久记忆则构建结构化的项目知识图谱与文档体系。本文深入剖析三种记忆的设计原理、数据结构、存储策略与检索机制,详细讲解混合搜索(向量 + BM25 + 图遍历)的实现路径,阐述基于重要性评分与摘要生成的记忆压缩算法,并从敏感信息识别、访问控制、加密存储三个维度构建隐私安全保障体系。最后,通过完整的 Python/TypeScript 代码实现展示三层记忆架构的工程落地,为构建企业级 AI IDE 记忆系统提供可直接引用的架构方案与核心源码。
本节为你提供的核心技术价值:理解记忆系统对于 AI IDE 持续智能化运作的必要性,以及三层记忆架构的整体设计哲学。
在传统软件开发场景中,程序员的上下文切换成本极高——切换项目、切换任务、隔天继续工作时,大量相关上下文已经丢失。AI IDE 的核心竞争力之一,就是通过构建完整的记忆系统,让 AI 助手能够在任意时间点准确恢复工作上下文,持续学习项目知识,并在跨会话、跨项目中积累智慧。
一个设计良好的 Memory System 需要回答以下核心问题:
本文将逐一深入解答这些问题。首先从三层记忆架构的整体设计开始。
本节为你提供的核心技术价值:掌握短期、长期、永久记忆的层次关系、数据特征与协作模式,理解"记忆金字塔"的设计哲学。
三层记忆架构的设计灵感来源于认知科学中的人类记忆模型[1]。人类记忆系统通常分为感觉记忆、短时记忆(工作记忆)和长时记忆三个层次,各层次在容量、持续时间、编码方式上存在本质差异。AI IDE 的 Memory System 同样遵循这一规律,但针对软件开发的领域特征进行了适配。

| 特征维度 | 短期记忆 | 长期记忆 | 永久记忆 |
|---|---|---|---|
| 数据形态 | 结构化状态对象 | 浮点向量 + 元数据 | 图结构 + 文档 |
| 存储位置 | 内存 (Redis/In-process) | 向量数据库 (Pinecone/Milvus) | 图数据库 (Neo4j) + 文件系统 |
| 容量 | ~100KB-10MB | ~1GB-100GB | 无硬性限制 |
| 持续时间 | 当前会话 | 30-90天(可配置) | 永久 |
| 访问延迟 | <1ms | 10-100ms | 50-500ms |
| 编码方式 | JSON/结构化对象 | 嵌入向量 (1536-4096维) | RDF/属性图 |
| 检索方式 | 精确键查找 | 近似最近邻 (ANN) | 图遍历 + 索引 |
| 遗忘策略 | 会话结束时清除 | 基于访问频率衰减 | 仅手动删除 |
三层记忆之间存在明确的流动方向:
写入路径(Information Flow):
读取路径(Retrieval Flow):
三层记忆架构的设计遵循以下核心原则:
本节为你提供的核心技术价值:掌握短期记忆的数据结构设计、状态快照机制、会话管理策略,以及基于工作上下文窗口的 AI 响应增强方法。
短期记忆(Working Memory)是 AI IDE 感知当前工作状态的"感官系统"。它的核心职责包括:
短期记忆的生命周期与用户会话强绑定。会话开始时创建,会话结束时(通常由用户显式关闭或超时触发)销毁。这种设计确保了用户的隐私——未持久化的操作不会留下痕迹。
短期记忆的核心数据结构是 WorkingContext,它是一个复合对象,包含了当前工作状态的所有维度:
# short_term_memory/working_context.py
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from datetime import datetime
from enum import Enum
import uuid
class ContextLevel(Enum):
    """Context tier, expressing how important and how stable information is."""

    # Ephemeral: cursor position, selected text.
    EPHEMERAL = "ephemeral"
    # Operational: current file, edit history.
    OPERATIONAL = "operational"
    # Conceptual: opened functions, data structures.
    CONCEPTUAL = "conceptual"
    # Project level: project structure, workspace configuration.
    PROJECTIONAL = "projectional"
@dataclass
class FileReference:
    """Reference to a file together with its full editing context."""

    # Identity and basic file facts; file_id defaults to a fresh UUID4 string.
    file_id: str = field(default_factory=lambda: f"{uuid.uuid4()}")
    file_path: str = ""
    language: str = ""
    size_bytes: int = 0
    last_modified: datetime = field(default_factory=datetime.now)

    # Syntactic structure.
    top_level_definitions: List[Dict[str, Any]] = field(default_factory=list)
    imports: List[str] = field(default_factory=list)

    # Edit state.
    is_dirty: bool = False
    recent_edits: List[Dict[str, Any]] = field(default_factory=list)

    # Cursor and selection.
    cursor_position: Optional[Dict[str, int]] = None
    selection_range: Optional[Dict[str, int]] = None

    # Related files (same module, test files, type definitions, ...).
    related_files: List[str] = field(default_factory=list)
@dataclass
class ConversationTurn:
    """One conversation turn: a single user interaction and the AI's reply."""

    turn_id: str = field(default_factory=lambda: f"{uuid.uuid4()}")
    timestamp: datetime = field(default_factory=datetime.now)

    # User input; user_intent holds the parsed intent, if any.
    user_message: str = ""
    user_intent: Optional[str] = None

    # AI response and the actions it executed.
    ai_response: str = ""
    ai_actions: List[Dict[str, Any]] = field(default_factory=list)

    # Context this turn depended on.
    referenced_files: List[str] = field(default_factory=list)
    referenced_symbols: List[str] = field(default_factory=list)

    # User feedback.
    user_feedback: Optional[str] = None
    is_positive_feedback: Optional[bool] = None
@dataclass
class TaskState:
    """State of the task that is currently being carried out."""

    task_id: str = field(default_factory=lambda: f"{uuid.uuid4()}")
    # One of "edit", "read", "search", "debug", "refactor".
    task_type: str = ""
    description: str = ""

    # Task breakdown.
    sub_tasks: List[Dict[str, Any]] = field(default_factory=list)
    current_sub_task_index: int = 0

    # Progress, from 0.0 to 1.0, and whatever is blocking it.
    progress: float = 0.0
    blockers: List[str] = field(default_factory=list)

    # Associated resources.
    primary_file: Optional[str] = None
    secondary_files: List[str] = field(default_factory=list)
@dataclass
class WorkingContext:
    """Working-context container; the core data structure of short-term memory.

    Holds the open files, conversation history, current task, and a bounded
    stream of recent user operations for one session.
    """

    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)

    # Basic information. Project types are quoted so the class does not depend
    # on definition order.
    project_root: Optional[str] = None
    opened_files: Dict[str, "FileReference"] = field(default_factory=dict)
    active_file: Optional[str] = None

    # Conversation and task state.
    conversation_history: List["ConversationTurn"] = field(default_factory=list)
    current_task: Optional["TaskState"] = None

    # Stream of the most recent user operations, capped at max_stream_size.
    operation_stream: List[Dict[str, Any]] = field(default_factory=list)
    max_stream_size: int = 100

    # LLM context-window bookkeeping.
    context_tokens_used: int = 0
    context_token_limit: int = 128000

    # Free-form metadata.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_operation(self, operation: Dict[str, Any]) -> None:
        """Append ``operation`` to the operation stream with a timestamp.

        A shallow copy of ``operation`` is stored with a ``"timestamp"`` key
        (ISO-8601, local time); the caller's dict is left untouched. The
        oldest entries are dropped so the stream never exceeds
        ``max_stream_size``.

        Args:
            operation: Arbitrary description of the user operation.
        """
        now = datetime.now()
        # Copy instead of writing into the caller's dict: callers may reuse it.
        self.operation_stream.append({**operation, "timestamp": now.isoformat()})
        # Trim everything over the cap in one step. This also recovers when
        # max_stream_size was lowered after entries had accumulated.
        overflow = len(self.operation_stream) - max(self.max_stream_size, 0)
        if overflow > 0:
            del self.operation_stream[:overflow]
        self.last_accessed = now

    def add_conversation_turn(self, turn: "ConversationTurn") -> None:
        """Append ``turn`` to the conversation history and refresh ``last_accessed``."""
        self.conversation_history.append(turn)
        self.last_accessed = datetime.now()
# NOTE(review): presumably a method of WorkingContext (it reads that class's
# fields through `self`); class-level indentation was lost in this extract.
def get_context_summary(self) -> str:
    """Build a one-line summary of the context for a quick view of the current state.

    Returns:
        The summary parts joined with " | ". The opened-file count and the
        conversation-turn count are always present; the active file and the
        task description/progress appear only when set.
    """
    summary_parts = []
    if self.active_file:
        summary_parts.append(f"活跃文件: {self.active_file}")
    if self.current_task:
        summary_parts.append(f"当前任务: {self.current_task.description}")
        # progress is multiplied by 100 and shown as a whole-number percentage.
        summary_parts.append(f"任务进度: {self.current_task.progress * 100:.0f}%")
    opened_count = len(self.opened_files)
    summary_parts.append(f"打开文件数: {opened_count}")
    conversation_count = len(self.conversation_history)
    summary_parts.append(f"对话轮次: {conversation_count}")
    return " | ".join(summary_parts)

状态快照(State Snapshot)是短期记忆的核心能力之一。它允许在任意时间点保存完整的工作状态,并在需要时恢复。快照机制对于以下场景至关重要:
# short_term_memory/snapshot.py
import json
import hashlib
from typing import Optional, List, Callable, Dict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import pickle
import gzip
class SnapshotType(Enum):
    """Kind of snapshot."""

    # Full snapshot containing all data.
    FULL = "full"
    # Incremental snapshot containing only changes.
    INCREMENTAL = "incremental"
    # Minimal snapshot containing only key state.
    MINIMAL = "minimal"
@dataclass
class Snapshot:
    """A state snapshot and the bookkeeping stored alongside it."""

    snapshot_id: str = ""
    parent_snapshot_id: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    snapshot_type: SnapshotType = SnapshotType.FULL

    # Snapshot payload (after compression).
    content: bytes = b""

    # Checksum of the payload.
    checksum: str = ""

    # Descriptive metadata.
    description: str = ""
    tags: List[str] = field(default_factory=list)
    size_bytes: int = 0
class SnapshotManager:
    """Create, cache, restore, and prune snapshots of a working context.

    Snapshots are cached in memory per session and, when a storage backend is
    configured, mirrored to it on creation.
    """

    # Leading bytes of every gzip stream; used to recognise compressed payloads.
    _GZIP_MAGIC = b"\x1f\x8b"

    def __init__(
        self,
        storage_backend: Optional[Callable] = None,
        max_snapshots: int = 50,
        compression_enabled: bool = True
    ):
        """Initialise the manager.

        Args:
            storage_backend: Optional persistence object. Despite the
                ``Callable`` annotation it is used as an object exposing
                ``save(snapshot)`` and ``load(snapshot_id)``.
            max_snapshots: Maximum number of cached snapshots per session.
            compression_enabled: Whether new snapshots are gzip-compressed.
        """
        self.storage_backend = storage_backend
        self.max_snapshots = max_snapshots
        self.compression_enabled = compression_enabled
        # In-memory cache.
        self._snapshots: Dict[str, Snapshot] = {}
        # session_id -> snapshot ids, in creation order.
        self._snapshots_by_context: Dict[str, List[str]] = {}

    def create_snapshot(
        self,
        context: WorkingContext,
        snapshot_type: SnapshotType = SnapshotType.FULL,
        description: str = "",
        tags: Optional[List[str]] = None
    ) -> Snapshot:
        """Serialize ``context`` into a new snapshot, store it, and return it.

        Args:
            context: The working context to capture.
            snapshot_type: How much of the context to capture.
            description: Free-text description stored on the snapshot.
            tags: Optional tags stored on the snapshot.

        Returns:
            The newly created snapshot.
        """
        context_data = self._serialize_context(context, snapshot_type)
        content = gzip.compress(context_data) if self.compression_enabled else context_data
        snapshot = Snapshot(
            snapshot_id=self._generate_snapshot_id(context.session_id),
            snapshot_type=snapshot_type,
            content=content,
            # The checksum covers the stored (possibly compressed) bytes.
            checksum=hashlib.sha256(content).hexdigest(),
            description=description,
            tags=tags or [],
            size_bytes=len(content)
        )
        self._snapshots[snapshot.snapshot_id] = snapshot
        self._snapshots_by_context.setdefault(context.session_id, []).append(
            snapshot.snapshot_id
        )
        # Persist to the backend.
        if self.storage_backend:
            self.storage_backend.save(snapshot)
        # Prune old snapshots.
        self._prune_old_snapshots(context.session_id)
        return snapshot

    def restore_snapshot(self, snapshot_id: str) -> WorkingContext:
        """Load a snapshot by id and rebuild the working context from it.

        Args:
            snapshot_id: Id returned by ``create_snapshot``.

        Returns:
            The restored working context.

        Raises:
            ValueError: If the snapshot is unknown or fails its checksum.
        """
        snapshot = self._snapshots.get(snapshot_id)
        if snapshot is None and self.storage_backend:
            snapshot = self.storage_backend.load(snapshot_id)
        if snapshot is None:
            raise ValueError(f"Snapshot not found: {snapshot_id}")
        if snapshot.checksum != hashlib.sha256(snapshot.content).hexdigest():
            raise ValueError(f"Snapshot checksum mismatch: {snapshot_id}")
        payload = snapshot.content
        # Decide from the payload itself, not from the manager's current flag:
        # a snapshot may have been written under a different compression setting.
        if payload[:2] == self._GZIP_MAGIC:
            payload = gzip.decompress(payload)
        return self._deserialize_context(payload)

    def _serialize_context(
        self,
        context: WorkingContext,
        snapshot_type: SnapshotType
    ) -> bytes:
        """Serialize ``context`` to bytes.

        MINIMAL snapshots become a small JSON document. FULL and INCREMENTAL
        snapshots are both pickled whole; no diffing is implemented.
        """
        if snapshot_type != SnapshotType.MINIMAL:
            return pickle.dumps(context)
        task = context.current_task
        minimal_data = {
            "session_id": context.session_id,
            "project_root": context.project_root,
            "active_file": context.active_file,
            "current_task": {
                "task_id": task.task_id if task else None,
                "description": task.description if task else None,
                "progress": task.progress if task else 0
            },
            "conversation_count": len(context.conversation_history)
        }
        return json.dumps(minimal_data).encode()

    def _deserialize_context(self, data: bytes) -> WorkingContext:
        """Rebuild a working context from serialized bytes.

        A payload starting with ``{`` is the JSON document of a MINIMAL
        snapshot; anything else is treated as a pickle. A MINIMAL restore only
        recovers the fields that were saved; the saved conversation count is
        kept under ``metadata["restored_conversation_count"]``.
        """
        if data[:1] == b"{":
            minimal = json.loads(data.decode())
            context = WorkingContext(
                session_id=minimal["session_id"],
                project_root=minimal.get("project_root"),
                active_file=minimal.get("active_file"),
            )
            task_info = minimal.get("current_task") or {}
            if task_info.get("task_id") is not None:
                context.current_task = TaskState(
                    task_id=task_info["task_id"],
                    description=task_info.get("description") or "",
                    progress=task_info.get("progress") or 0,
                )
            context.metadata["restored_conversation_count"] = minimal.get(
                "conversation_count", 0
            )
            return context
        # SECURITY: pickle.loads can execute arbitrary code. Only restore
        # snapshots from a storage backend that is fully trusted.
        return pickle.loads(data)

    def _generate_snapshot_id(self, session_id: str) -> str:
        """Return a 16-hex-character snapshot id for ``session_id``."""
        import uuid  # Local import: not in this snippet's import list.

        # The random component keeps ids distinct when two snapshots of one
        # session are created within the same clock tick. MD5 is only a
        # fingerprint here, not a security measure.
        raw = f"{session_id}_{datetime.now().isoformat()}_{uuid.uuid4().hex}"
        return hashlib.md5(raw.encode()).hexdigest()[:16]
# NOTE(review): presumably a method of SnapshotManager (create_snapshot calls
# self._prune_old_snapshots); class-level indentation was lost in this extract.
def _prune_old_snapshots(self, session_id: str) -> None:
    """Drop the oldest cached snapshots of a session beyond ``max_snapshots``.

    Only the in-memory cache and the per-session id list are touched here;
    the storage backend is not called.
    """
    if session_id not in self._snapshots_by_context:
        return
    snapshot_ids = self._snapshots_by_context[session_id]
    if len(snapshot_ids) <= self.max_snapshots:
        return
    # Delete the oldest snapshots: ids are appended in creation order, so
    # everything before the last max_snapshots entries is the oldest.
    # NOTE(review): with max_snapshots == 0 the slices [:-0] and [-0:] select
    # nothing and everything respectively, so nothing is pruned — confirm.
    to_delete = snapshot_ids[:-self.max_snapshots]
    for snapshot_id in to_delete:
        if snapshot_id in self._snapshots:
            del self._snapshots[snapshot_id]
    self._snapshots_by_context[session_id] = snapshot_ids[-self.max_snapshots:]

AI IDE 的短期记忆面临一个关键挑战:上下文窗口是有限的。即使是 GPT-4 128K 上下文,在处理大型项目时也可能捉襟见肘。因此,上下文窗口管理成为短期记忆设计的核心议题。
# short_term_memory/context_window.py
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
import tiktoken
@dataclass
class ContextWindowConfig:
    """Configuration of the context window."""

    model_name: str = "gpt-4"
    max_tokens: int = 128000
    # Space held back for system messages and the response.
    reserved_tokens: int = 8000
    # Compression is triggered once usage reaches 85%.
    compression_threshold: float = 0.85
class ContextWindowManager:
    """Track token usage of the context window and shrink history to fit it."""

    def __init__(self, config: "ContextWindowConfig"):
        """Initialise the manager.

        Args:
            config: Window configuration (model name, limits, threshold).
        """
        self.config = config
        try:
            self.encoding = tiktoken.encoding_for_model(config.model_name)
        except KeyError:
            # tiktoken raises KeyError for model names it cannot map to a
            # tokenizer; fall back to a general-purpose encoding so counting
            # still works (approximately) for unknown models.
            self.encoding = tiktoken.get_encoding("cl100k_base")
        self._current_tokens: int = 0

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``."""
        return len(self.encoding.encode(text))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Estimate the total token count of a message list.

        Each message costs its content tokens plus a flat 4-token allowance
        for formatting (role, name, content, separator).
        """
        total = 0
        for msg in messages:
            # Tool-call messages can carry content=None; count them as empty.
            total += self.count_tokens(msg.get("content") or "")
            total += 4
        return total

    def get_available_tokens(self) -> int:
        """Return the tokens still free after the reservation and current usage."""
        return self.config.max_tokens - self.config.reserved_tokens - self._current_tokens

    def get_usage_ratio(self) -> float:
        """Return current usage as a fraction of the usable budget.

        When the reservation consumes the whole window (budget <= 0) the
        window is reported as full (1.0) instead of dividing by zero.
        """
        budget = self.config.max_tokens - self.config.reserved_tokens
        if budget <= 0:
            return 1.0
        return self._current_tokens / budget

    def update_tokens(self, delta: int) -> None:
        """Adjust the current token count by ``delta``, never below zero."""
        self._current_tokens = max(0, self._current_tokens + delta)

    def needs_compression(self) -> bool:
        """Return True once usage reaches the configured compression threshold."""
        return self.get_usage_ratio() >= self.config.compression_threshold

    def select_relevant_messages(
        self,
        messages: List[Dict[str, str]],
        max_messages: Optional[int] = None
    ) -> Tuple[List[Dict[str, str]], int]:
        """Select the most relevant subset of ``messages``.

        Scoring favours recent messages (up to 0.3), messages whose metadata
        has a ``file_path`` (0.4), and assistant messages with ``tool_calls``
        (0.3). The selection is returned in original order.

        Args:
            messages: Conversation messages, oldest first.
            max_messages: Maximum number to keep; ``None`` keeps all, and
                negative values are treated as 0.

        Returns:
            ``(selected_messages, saved_tokens)``.
        """
        limit = len(messages) if max_messages is None else max(max_messages, 0)
        scored_messages = []
        for i, msg in enumerate(messages):
            # Position weight: the more recent, the higher.
            score = (i / len(messages)) * 0.3
            # metadata may be absent or explicitly None.
            if "file_path" in (msg.get("metadata") or {}):
                score += 0.4
            # Tool calls usually carry important context.
            if msg.get("role") == "assistant" and "tool_calls" in msg:
                score += 0.3
            scored_messages.append((score, i, msg))
        # Stable sort: ties keep their original relative order.
        scored_messages.sort(key=lambda item: item[0], reverse=True)
        selected = sorted(scored_messages[:limit], key=lambda item: item[1])
        selected_messages = [msg for _, _, msg in selected]
        saved_tokens = (
            self.count_messages_tokens(messages)
            - self.count_messages_tokens(selected_messages)
        )
        return selected_messages, saved_tokens

    def summarize_and_compress(
        self,
        messages: List[Dict[str, str]],
        summary_model: str = "gpt-3.5-turbo"
    ) -> Tuple[List[Dict[str, str]], str]:
        """Compress history by replacing early messages with a summary.

        Histories of 4 messages or fewer are returned unchanged with an empty
        summary. Otherwise the last 2 messages are kept verbatim and the rest
        is replaced by one system message holding the summary.

        Args:
            messages: Conversation messages, oldest first.
            summary_model: Model intended for summarisation; unused while the
                LLM call is a placeholder.

        Returns:
            ``(compressed_messages, summary)``.
        """
        if len(messages) <= 4:
            return messages, ""
        preserve_count = 2
        compressable = messages[:-preserve_count]
        preserved = messages[-preserve_count:]
        # The prompt is the input for the (not yet wired) LLM call below.
        summary_prompt = self._build_summary_prompt(compressable)
        # In a real project an LLM would generate the summary here:
        # summary = call_llm(summary_prompt, model=summary_model)
        summary = "[早期对话摘要占位符]"
        compressed = [
            {"role": "system", "content": f"对话历史摘要:\n{summary}"}
        ] + preserved
        return compressed, summary
# NOTE(review): presumably a method of ContextWindowManager
# (summarize_and_compress calls self._build_summary_prompt); class-level
# indentation was lost in this extract.
def _build_summary_prompt(self, messages: List[Dict[str, str]]) -> str:
    """Build the prompt that asks for a summary of ``messages``.

    Each message is rendered as "role: content" (role defaults to "user",
    content to an empty string), one per line, and embedded in the prompt text.
    """
    formatted = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        formatted.append(f"{role}: {content}")
    # chr(10) is the newline character, used to join the rendered lines
    # inside the f-string expression.
    return f"""请为以下对话生成简洁摘要,保留关键信息:
{chr(10).join(formatted)}
摘要应包含:
1. 讨论的主要主题
2. 做出的决定
3. 涉及的文件和技术
4. 未解决的问题(如有)
"""

短期记忆的"遗忘"是设计层面的必然,而非缺陷。并非所有信息都值得保留到长期记忆,遗忘策略确保只有高价值信息进入下一层。
遗忘策略分为两类:
# short_term_memory/forgetting.py
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import math
class ImportanceLevel(Enum):
    """Importance level of a memory."""

    # Must keep: passwords, keys, architecture decisions.
    CRITICAL = 5
    # High value: core business logic, explicitly marked by the user.
    HIGH = 4
    # Medium value: regular development context.
    MEDIUM = 3
    # Low value: temporary operations, exploratory code.
    LOW = 2
    # Forgettable: logs, comment edits.
    TRIVIAL = 1
@dataclass
class MemoryBlock:
    """A unit of memory with the statistics used to score its value."""

    block_id: str
    # Arbitrary payload. Annotated as `object` because the original `any` is
    # the builtin function, not a type.
    content: object
    importance: ImportanceLevel = ImportanceLevel.MEDIUM

    # Access statistics.
    access_count: int = 0
    last_accessed: datetime = field(default_factory=datetime.now)
    first_created: datetime = field(default_factory=datetime.now)

    # Multiplier applied to the final score; 1.0 means no decay.
    decay_factor: float = 1.0

    # Tags.
    tags: List[str] = field(default_factory=list)

    def access(self) -> None:
        """Record one access and let the decay factor recover by 0.1 (max 1.0)."""
        self.access_count += 1
        self.last_accessed = datetime.now()
        self.decay_factor = min(1.0, self.decay_factor + 0.1)

    def calculate_score(self) -> float:
        """Compute the value score of this memory.

        The score is ``(base + frequency + recency) * decay_factor`` where
        base is 20 points per importance level, frequency is
        ``5 * ln(1 + access_count)``, and recency starts at 10 points and
        loses 1 point per day of age, bottoming out at 0 after 10 days.

        Returns:
            The score as a float.
        """
        base_score = self.importance.value * 20
        frequency_score = math.log1p(self.access_count) * 5
        age_days = (datetime.now() - self.first_created).total_seconds() / 86400
        time_score = max(0, 10 - age_days)
        return (base_score + frequency_score + time_score) * self.decay_factor
class ForgettingPolicy:
    """Forgetting policy: decides what is persisted and what is deleted."""

    # System-level tag sets.
    CRITICAL_TAGS = {"password", "secret", "key", "token", "credential", "api_key"}
    HIGH_VALUE_TAGS = {"architecture", "design", "decision", "bug", "critical"}

    @classmethod
    def should_persist(cls, memory_block: MemoryBlock) -> bool:
        """Decide whether the block should be persisted to long-term memory.

        A block is persisted when it carries a critical tag, when its
        importance is CRITICAL or HIGH, or when its score is at least 70.
        """
        has_critical_tag = bool(set(memory_block.tags) & cls.CRITICAL_TAGS)
        if has_critical_tag:
            return True
        high_levels = [ImportanceLevel.CRITICAL, ImportanceLevel.HIGH]
        if memory_block.importance in high_levels:
            return True
        # The score is only computed when the cheaper checks did not decide.
        return memory_block.calculate_score() >= 70

    @classmethod
    def should_delete(cls, memory_block: MemoryBlock) -> bool:
        """Decide whether the block should be deleted.

        A block is deleted when its score is below 20 and it has not been
        accessed for more than 7 days, or when it is TRIVIAL and was never
        accessed.
        """
        seven_days = 86400 * 7
        if memory_block.calculate_score() < 20:
            idle_seconds = (datetime.now() - memory_block.last_accessed).total_seconds()
            if idle_seconds > seven_days:
                return True
        never_accessed = memory_block.access_count == 0
        return never_accessed and memory_block.importance == ImportanceLevel.TRIVIAL
# NOTE(review): presumably a classmethod of ForgettingPolicy; class-level
# indentation was lost in this extract.
@classmethod
def get_decay_rate(cls, memory_block: MemoryBlock) -> float:
    """Return the decay rate for ``memory_block``.

    The base rate is 0.05. CRITICAL blocks get 10% of it and HIGH blocks 50%;
    otherwise any block with at least one tag gets 70%, and the rest get the
    base rate.
    """
    base_decay = 0.05  # 5% per day
    # High-importance content decays more slowly.
    if memory_block.importance == ImportanceLevel.CRITICAL:
        return base_decay * 0.1
    elif memory_block.importance == ImportanceLevel.HIGH:
        return base_decay * 0.5
    # Tagged content decays more slowly.
    if memory_block.tags:
        return base_decay * 0.7
    return base_decay

本节为你提供的核心技术价值:掌握向量嵌入的生成与存储、向量数据库的选型与部署,以及基于语义相似度的跨会话检索机制。
长期记忆(Long-Term Memory)是 AI IDE 的"经验库"。它的核心职责是:
长期记忆的数据形态是向量嵌入——将文本、代码、对话等非结构化数据转换为高维向量,通过向量之间的几何关系表达语义相似性。
嵌入生成是长期记忆的入口。需要考虑:
# long_term_memory/embedding.py
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import hashlib
import tiktoken
class ChunkStrategy(Enum):
    """Chunking strategy."""

    # Fixed-size chunks.
    FIXED_SIZE = "fixed_size"
    # Semantic chunks (by paragraph / function).
    SEMANTIC = "semantic"
    # Hierarchical chunks (file -> class -> function).
    HIERARCHICAL = "hierarchical"
@dataclass
class MemoryChunk:
    """Memory chunk: the basic unit an embedding vector is attached to."""

    chunk_id: str = ""
    content: str = ""
    # Position within the original document.
    chunk_index: int = 0

    # Embedding vector; None until one has been generated.
    embedding: Optional[List[float]] = None

    # Source metadata. source_type is one of "file", "conversation",
    # "document", "code"; source_id is the owning file/session id.
    source_type: str = ""
    source_id: str = ""
    source_path: Optional[str] = None

    # Content description.
    title: str = ""
    summary: str = ""
    tags: List[str] = field(default_factory=list)

    # Timestamps.
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)

    # Statistics; relevance_score is adjusted from feedback.
    access_count: int = 0
    relevance_score: float = 0.0

    # Hierarchy information (for hierarchical chunking).
    # hierarchy_level: 0 = file, 1 = class, 2 = function.
    hierarchy_level: int = 0
    parent_chunk_id: Optional[str] = None
    child_chunk_ids: List[str] = field(default_factory=list)
class EmbeddingGenerator:
    """Generate embedding-bearing memory chunks from text.

    The embedding API call is still a placeholder: every chunk receives a
    zero vector of ``embedding_dims`` elements.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        embedding_dims: int = 1536,
        batch_size: int = 100
    ):
        """Initialise the generator.

        Args:
            model_name: Name of the embedding model.
            embedding_dims: Length of the embedding vectors produced.
            batch_size: Intended batch size; unused while the API call is a
                placeholder.
        """
        self.model_name = model_name
        self.embedding_dims = embedding_dims
        self.batch_size = batch_size
        self.encoding = tiktoken.encoding_for_model("gpt-4")

    def generate(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> MemoryChunk:
        """Generate an embedding chunk for a single text.

        ``MemoryChunk`` has no ``metadata`` field, so metadata keys matching a
        chunk field (for example ``source_type`` or ``tags``) are copied onto
        that field. Other keys are ignored, and the chunk's id, content, and
        embedding cannot be overridden.

        Args:
            text: The text to embed.
            metadata: Optional field values to set on the chunk.

        Returns:
            The new chunk with its embedding populated.
        """
        from dataclasses import fields  # Local import: not in this snippet's import list.

        chunk = MemoryChunk(
            chunk_id=self._generate_chunk_id(text),
            content=text
        )
        settable = {f.name for f in fields(MemoryChunk)} - {"chunk_id", "content", "embedding"}
        for key, value in (metadata or {}).items():
            if key in settable:
                setattr(chunk, key, value)
        # Call the embedding API here (OpenAI SDK or a self-hosted model):
        # chunk.embedding = self._call_embedding_api(text)
        # Placeholder until then.
        chunk.embedding = [0.0] * self.embedding_dims
        return chunk

    def generate_batch(
        self,
        texts: List[str],
        source_id: str,
        source_type: str = "document"
    ) -> List[MemoryChunk]:
        """Generate embedding chunks for several texts from one source.

        Args:
            texts: Texts to embed, in document order.
            source_id: Id of the file/session the texts belong to.
            source_type: Source type recorded on every chunk.

        Returns:
            One chunk per text, with ``chunk_index`` set to its position.
        """
        chunks = []
        for i, text in enumerate(texts):
            chunk = MemoryChunk(
                # The id mixes source and position so identical texts from
                # different places get distinct ids.
                chunk_id=self._generate_chunk_id(f"{source_id}_{i}_{text}"),
                content=text,
                chunk_index=i,
                source_type=source_type,
                source_id=source_id
            )
            # Batch API call goes here (to be optimised in a real project):
            # chunk.embedding = self._call_embedding_api(text)
            chunk.embedding = [0.0] * self.embedding_dims
            chunks.append(chunk)
        return chunks

    def _generate_chunk_id(self, content: str) -> str:
        """Return a 16-hex-character id derived from the MD5 of ``content``."""
        return hashlib.md5(content.encode()).hexdigest()[:16]
class TextChunker:
    """Split text or source code into chunks, with sizes measured in tokens."""

    def __init__(
        self,
        strategy: ChunkStrategy = ChunkStrategy.SEMANTIC,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        max_chunk_size: int = 1024
    ):
        """Initialise the chunker.

        Args:
            strategy: Strategy used by ``chunk_text``.
            chunk_size: Target chunk size in tokens.
            chunk_overlap: Overlap between consecutive fixed-size chunks, in
                tokens.
            max_chunk_size: Character cap on the file-level chunk produced by
                hierarchical chunking.
        """
        self.strategy = strategy
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.max_chunk_size = max_chunk_size
        # Every chunking path counts tokens through self.encoding, which was
        # never assigned before.
        self.encoding = tiktoken.encoding_for_model("gpt-4")

    def chunk_text(
        self,
        text: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Split ``text`` into chunks using the configured strategy.

        Args:
            text: The text to split.
            metadata: Optional metadata forwarded to hierarchical chunking.

        Returns:
            The chunks, in document order.
        """
        if self.strategy == ChunkStrategy.SEMANTIC:
            return self._chunk_by_semantics(text)
        if self.strategy == ChunkStrategy.HIERARCHICAL:
            return self._chunk_hierarchical(text, metadata or {})
        # FIXED_SIZE, and the fallback for any unrecognised strategy.
        return self._chunk_by_size(text)

    def _chunk_by_size(self, text: str) -> List[str]:
        """Split ``text`` into fixed-size, overlapping token windows."""
        tokens = self.encoding.encode(text)
        size = max(1, self.chunk_size)
        # Advance by at least one token, so overlap >= size cannot loop forever.
        step = max(1, size - self.chunk_overlap)
        chunks = []
        for start in range(0, len(tokens), step):
            chunks.append(self.encoding.decode(tokens[start:start + size]))
            # Stop once a window reaches the end; a further window would only
            # repeat the overlap.
            if start + size >= len(tokens):
                break
        return chunks

    def _chunk_by_semantics(self, text: str) -> List[str]:
        """Split ``text`` on blank lines and pack paragraphs up to ``chunk_size``.

        When a chunk is closed, its last paragraph is repeated at the start of
        the next chunk as overlap.
        """
        chunks = []
        current_chunk = []
        current_tokens = 0
        for para in text.split("\n\n"):
            para_tokens = len(self.encoding.encode(para))
            if current_chunk and current_tokens + para_tokens > self.chunk_size:
                chunks.append("\n\n".join(current_chunk))
                overlap_text = current_chunk[-1]
                if overlap_text:
                    current_chunk = [overlap_text, para]
                    current_tokens = len(self.encoding.encode(overlap_text)) + para_tokens
                else:
                    current_chunk = [para]
                    current_tokens = para_tokens
            else:
                current_chunk.append(para)
                current_tokens += para_tokens
        if current_chunk:
            chunks.append("\n\n".join(current_chunk))
        return chunks

    def _chunk_hierarchical(
        self,
        text: str,
        metadata: Dict[str, Any]
    ) -> List[str]:
        """Split ``text`` into a file-level chunk plus class/function chunks.

        This is a line-based simplification (no AST): a new class chunk starts
        at each line beginning with ``class `` and a new function chunk at
        each line beginning with ``def ``. The final class and function chunks
        are emitted too; previously they were dropped.

        Args:
            text: Source text to split.
            metadata: Accepted for interface compatibility; currently unused.

        Returns:
            The file-level chunk followed by the class and function chunks.
        """
        # Level 0: file chunk, capped by characters.
        chunks = [text[:self.max_chunk_size]]
        class_code = []
        function_code = []
        seen_class = False
        seen_function = False
        for line in text.split("\n"):
            stripped = line.strip()
            if stripped.startswith("class "):
                if class_code:
                    chunks.append("\n".join(class_code))
                    class_code = []
                seen_class = True
            if stripped.startswith("def "):
                if function_code:
                    chunks.append("\n".join(function_code))
                    function_code = []
                seen_function = True
            class_code.append(line)
            function_code.append(line)
        # Flush trailing buffers, but only for kinds that actually occurred;
        # otherwise header-less text would be duplicated.
        if seen_class and class_code:
            chunks.append("\n".join(class_code))
        if seen_function and function_code:
            chunks.append("\n".join(function_code))
        return chunks

    def chunk_code(
        self,
        code: str,
        language: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """Split source code into chunks.

        Goals for code chunking: keep syntax intact (no split mid-statement),
        prefer function/class boundaries, and keep comments and docstrings.

        Args:
            code: The source code.
            language: Language name; a fixed list of languages uses structural
                chunking and everything else falls back to ``chunk_text``.
            metadata: Optional metadata forwarded to ``chunk_text``.

        Returns:
            The chunks, in document order.
        """
        if language in ["python", "javascript", "typescript", "java", "go", "rust"]:
            return self._chunk_code_by_structure(code)
        return self.chunk_text(code, metadata)
# NOTE(review): presumably a method of TextChunker (chunk_code calls
# self._chunk_code_by_structure); class-level indentation was lost in this
# extract.
def _chunk_code_by_structure(self, code: str) -> List[str]:
    """Chunk code by structure (simplified implementation).

    Lines are accumulated until adding the next line would push the running
    token count past ``chunk_size``; the chunk is then emitted and a new one
    is started.
    """
    chunks = []
    # Simplification: split on lines rather than on real syntax.
    lines = code.split("\n")
    current_chunk = []
    current_size = 0
    for line in lines:
        line_size = len(self.encoding.encode(line))
        if current_size + line_size > self.chunk_size and current_chunk:
            chunks.append("\n".join(current_chunk))
            # Carry the last chunk_overlap *lines* into the next chunk.
            # NOTE(review): with chunk_overlap == 0 the slice [-0:] keeps
            # every line instead of none — confirm this is intended.
            current_chunk = current_chunk[-self.chunk_overlap:]
            current_size = sum(len(self.encoding.encode(l)) for l in current_chunk)
        current_chunk.append(line)
        current_size += line_size
    if current_chunk:
        chunks.append("\n".join(current_chunk))
    return chunks

向量数据库是长期记忆的存储引擎。主流选择包括:
| 数据库 | 优势 | 适用场景 | 缺点 |
|---|---|---|---|
| Pinecone | 托管服务、低延迟、全托管 | 快速原型、企业应用 | 成本高、数据主权 |
| Milvus | 开源、可私有部署、功能全面 | 大规模数据、需要完全控制 | 运维复杂 |
| Qdrant | Rust 实现、性能高、支持过滤 | 高性能需求、实时系统 | 生态较新 |
| ChromaDB | 轻量级、易用、Python 原生 | 原型、小规模数据 | 功能有限 |
| Weaviate | 原生支持混合搜索 | 需要 BM25 + 向量融合 | 资源消耗较高 |
# long_term_memory/vector_store.py
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from abc import ABC, abstractmethod
import numpy as np
@dataclass
class VectorSearchResult:
"""向量搜索结果"""
chunk_id: str
content: str
score: float # 相似度分数
metadata: Dict[str, Any] = field(default_factory=dict)
distance: Optional[float] = None # 欧氏距离
class VectorStore(ABC):
    """Abstract base class that every vector storage backend implements."""

    @abstractmethod
    def insert(self, chunk: MemoryChunk) -> bool:
        """Insert a single chunk."""

    @abstractmethod
    def insert_batch(self, chunks: List[MemoryChunk]) -> int:
        """Insert several chunks."""

    @abstractmethod
    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[VectorSearchResult]:
        """Run a vector similarity search."""

    @abstractmethod
    def delete(self, chunk_id: str) -> bool:
        """Delete a chunk."""

    @abstractmethod
    def update(self, chunk: MemoryChunk) -> bool:
        """Update a chunk."""
class InMemoryVectorStore(VectorStore):
    """In-memory vector store for tests and small data sets.

    All embeddings are kept in one dense matrix that is rebuilt after every
    mutation; searches are exhaustive.
    """

    def __init__(self, dimension: int = 1536, metric: str = "cosine"):
        """Initialise an empty store.

        Args:
            dimension: Required length of every embedding.
            metric: ``"cosine"`` or ``"euclidean"``.
        """
        self.dimension = dimension
        self.metric = metric
        self._chunks: Dict[str, MemoryChunk] = {}
        self._ids: List[str] = []
        # Start with an empty (0, dimension) matrix so search() works on a
        # store nothing was inserted into.
        self._embeddings: np.ndarray = np.empty((0, dimension))

    def _store_chunk(self, chunk: MemoryChunk) -> bool:
        """Record ``chunk`` without rebuilding the index; return False if rejected."""
        if chunk.embedding is None or len(chunk.embedding) != self.dimension:
            # A missing or wrongly sized vector would corrupt the matrix.
            return False
        self._chunks[chunk.chunk_id] = chunk
        return True

    def insert(self, chunk: MemoryChunk) -> bool:
        """Insert or replace one chunk; return False if its embedding is unusable."""
        if not self._store_chunk(chunk):
            return False
        self._rebuild_index()
        return True

    def insert_batch(self, chunks: List[MemoryChunk]) -> int:
        """Insert several chunks and return how many were actually stored.

        The index is rebuilt once for the whole batch.
        """
        inserted = sum(1 for chunk in chunks if self._store_chunk(chunk))
        if inserted:
            self._rebuild_index()
        return inserted

    def search(
        self,
        query_embedding: List[float],
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[VectorSearchResult]:
        """Return up to ``top_k`` chunks most similar to ``query_embedding``.

        Filters are applied while walking the ranking, so chunks that fail
        them do not use up result slots.

        Args:
            query_embedding: Query vector of length ``dimension``.
            top_k: Maximum number of results.
            filters: Optional attribute filters; see ``_match_filters``.

        Returns:
            Results ordered from best to worst. ``score`` is the cosine
            similarity, or the negated distance for the euclidean metric, in
            which case ``distance`` is also set.

        Raises:
            ValueError: If the configured metric is unknown.
        """
        if not self._ids or top_k <= 0:
            return []
        query = np.asarray(query_embedding, dtype=float).reshape(1, -1)
        if self.metric == "cosine":
            scores = self._cosine_similarity(query, self._embeddings)[0]
        elif self.metric == "euclidean":
            scores = -self._euclidean_distance(query, self._embeddings)[0]
        else:
            raise ValueError(f"Unknown metric: {self.metric}")
        results = []
        for idx in np.argsort(scores)[::-1]:
            chunk = self._chunks[self._ids[idx]]
            if filters and not self._match_filters(chunk, filters):
                continue
            score = float(scores[idx])
            results.append(VectorSearchResult(
                chunk_id=chunk.chunk_id,
                content=chunk.content,
                score=score,
                metadata={
                    "source_type": chunk.source_type,
                    "source_id": chunk.source_id,
                    "created_at": chunk.created_at.isoformat()
                },
                distance=-score if self.metric == "euclidean" else None
            ))
            if len(results) == top_k:
                break
        return results

    def delete(self, chunk_id: str) -> bool:
        """Remove a chunk; return False if the id is unknown."""
        if chunk_id not in self._chunks:
            return False
        del self._chunks[chunk_id]
        self._rebuild_index()
        return True

    def update(self, chunk: MemoryChunk) -> bool:
        """Replace a chunk; this is the same operation as ``insert``."""
        return self.insert(chunk)

    def _rebuild_index(self) -> None:
        """Rebuild the id list and the embedding matrix from the stored chunks."""
        self._ids = list(self._chunks)
        if not self._ids:
            self._embeddings = np.empty((0, self.dimension))
            return
        self._embeddings = np.array([
            self._chunks[chunk_id].embedding for chunk_id in self._ids
        ])

    @staticmethod
    def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return the cosine similarity between each row of ``a`` and of ``b``."""
        a_norm = np.linalg.norm(a, axis=1, keepdims=True)
        b_norm = np.linalg.norm(b, axis=1, keepdims=True)
        # The epsilon avoids division by zero for all-zero vectors.
        return np.dot(a, b.T) / (a_norm * b_norm.T + 1e-8)

    @staticmethod
    def _euclidean_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Return the distance from the single query row ``a`` to each row of ``b``.

        The result has shape ``(1, n)``, matching ``_cosine_similarity``, so
        ``search`` can index both the same way.
        """
        return np.linalg.norm(a - b, axis=1).reshape(1, -1)
# NOTE(review): presumably a static method of InMemoryVectorStore (search
# calls self._match_filters); class-level indentation was lost in this extract.
@staticmethod
def _match_filters(chunk: MemoryChunk, filters: Dict[str, Any]) -> bool:
    """Check whether ``chunk`` matches every filter.

    Each filter key names a chunk attribute. A list value matches when the
    attribute is one of the list's items; any other value must equal the
    attribute. A missing attribute, or one that is None, never matches.
    """
    for key, value in filters.items():
        chunk_value = getattr(chunk, key, None)
        if chunk_value is None:
            return False
        if isinstance(value, list):
            if chunk_value not in value:
                return False
        elif chunk_value != value:
            return False
    return True

长期记忆不是静态存储,而是动态演化的系统。随着时间推移,一些记忆会变得陈旧,另一些则因频繁访问而强化。
# long_term_memory/long_term_manager.py
from typing import List, Optional, Dict, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import threading
import time
@dataclass
class MemoryEntry:
"""长期记忆条目"""
chunk: MemoryChunk
importance_score: float = 0.5 # 0.0 - 1.0
access_count: int = 0
last_accessed: datetime = field(default_factory=datetime.now)
created_at: datetime = field(default_factory=datetime.now)
# 衰减
decay_base: float = 0.95 # 每日衰减因子
# 来源标记
is_from_session: bool = True # True: 会话结束时的压缩, False: 项目文档导入
session_id: Optional[str] = None
def access(self) -> None:
"""访问记忆,强化重要性"""
self.access_count += 1
self.last_accessed = datetime.now()
# 访问强化:分数 +5%,上限 1.0
self.importance_score = min(1.0, self.importance_score + 0.05)
def calculate_current_score(self) -> float:
"""计算当前有效分数(含时间衰减)"""
days_elapsed = (datetime.now() - self.created_at).total_seconds() / 86400
decay_factor = pow(self.decay_base, days_elapsed)
# 访问频率强化
access_bonus = min(0.2, self.access_count * 0.02)
return self.importance_score * decay_factor + access_bonus
def should_promote(self, threshold: float = 0.8) -> bool:
"""是否应该晋升到永久记忆"""
return self.calculate_current_score() >= threshold
class LongTermMemoryManager:
    """Manage long-term memories: embed, store, retrieve, decay, and prune.

    NOTE(review): the background decay thread and the public methods share
    ``_memory_index`` without a lock; confirm callers do not mutate it while
    the decay process is running.
    """

    def __init__(
        self,
        vector_store: VectorStore,
        embedding_generator: EmbeddingGenerator,
        config: Optional[Dict[str, Any]] = None
    ):
        """Initialise the manager.

        Args:
            vector_store: Backend that stores and searches embeddings.
            embedding_generator: Produces chunks with embeddings from text.
            config: Optional overrides for ``promotion_threshold`` (0.8),
                ``deletion_threshold`` (0.15), ``max_memories`` (10000) and
                ``decay_interval_hours`` (24).
        """
        self.vector_store = vector_store
        self.embedding_generator = embedding_generator
        self.config = config or {}
        # chunk_id -> entry.
        self._memory_index: Dict[str, MemoryEntry] = {}
        # Thresholds.
        self.promotion_threshold = self.config.get("promotion_threshold", 0.8)
        self.deletion_threshold = self.config.get("deletion_threshold", 0.15)
        self.max_memories = self.config.get("max_memories", 10000)
        # Decay settings.
        self.decay_interval_hours = self.config.get("decay_interval_hours", 24)
        self.decay_thread = None
        self._stop_decay = threading.Event()

    def add_memory(
        self,
        content: str,
        source_type: str,
        source_id: str,
        importance: float = 0.5,
        metadata: Optional[Dict[str, Any]] = None
    ) -> MemoryChunk:
        """Embed ``content``, store it, and index it as a new memory.

        Args:
            content: Text of the memory.
            source_type: Kind of source the memory came from.
            source_id: Id of that source.
            importance: Initial importance score, 0.0 to 1.0.
            metadata: Extra metadata passed to the embedding generator.

        Returns:
            The stored chunk.
        """
        chunk = self.embedding_generator.generate(
            text=content,
            metadata={
                "source_type": source_type,
                "source_id": source_id,
                **(metadata or {})
            }
        )
        self.vector_store.insert(chunk)
        self._memory_index[chunk.chunk_id] = MemoryEntry(
            chunk=chunk,
            importance_score=importance
        )
        # Enforce the capacity limit.
        if len(self._memory_index) > self.max_memories:
            self._prune_memories()
        return chunk

    def add_session_summary(
        self,
        summary: str,
        session_id: str,
        key_files: List[str],
        key_decisions: List[str]
    ) -> str:
        """Store a session summary as a memory and return its chunk id.

        Args:
            summary: Summary text of the session.
            session_id: Id of the summarised session.
            key_files: Files the session touched.
            key_decisions: Decisions made during the session.
        """
        content = "\n".join([
            f"会话摘要: {summary}",
            f"涉及文件: {', '.join(key_files)}",
            f"关键决策: {', '.join(key_decisions)}"
        ])
        chunk = self.add_memory(
            content=content,
            source_type="session_summary",
            source_id=session_id,
            importance=0.7,
            metadata={
                "session_id": session_id,
                "key_files": key_files,
                "key_decisions": key_decisions
            }
        )
        return chunk.chunk_id

    def retrieve(
        self,
        query: str,
        top_k: int = 10,
        time_filter: Optional[Dict[str, Any]] = None,
        source_filter: Optional[List[str]] = None
    ) -> List[VectorSearchResult]:
        """Retrieve the memories most relevant to ``query``.

        Every returned memory is marked as accessed, which reinforces its
        importance, and results are re-ranked as 70% vector score plus 30%
        importance.

        Args:
            query: Natural-language query.
            top_k: Maximum number of results.
            time_filter: Accepted for interface compatibility; currently not
                applied.
            source_filter: If given, only these source types are returned.

        Returns:
            Results sorted by combined score, best first.
        """
        query_chunk = self.embedding_generator.generate(query)
        filters = {"source_type": source_filter} if source_filter else None
        results = self.vector_store.search(
            query_embedding=query_chunk.embedding,
            top_k=top_k,
            filters=filters
        )
        for result in results:
            entry = self._memory_index.get(result.chunk_id)
            if entry is None:
                continue
            # Reinforce first, so the re-ranking sees the boosted importance.
            entry.access()
            result.score = result.score * 0.7 + entry.importance_score * 0.3
        results.sort(key=lambda r: r.score, reverse=True)
        return results

    def get_memories_for_permanent(self) -> List[MemoryEntry]:
        """Return entries due for promotion to permanent memory, best first."""
        candidates = [
            entry for entry in self._memory_index.values()
            if entry.should_promote(self.promotion_threshold)
        ]
        return sorted(candidates, key=lambda e: e.calculate_current_score(), reverse=True)

    def delete_memory(self, chunk_id: str) -> bool:
        """Delete a memory from the index and the vector store.

        Returns:
            False if ``chunk_id`` is not indexed, True otherwise.
        """
        if chunk_id not in self._memory_index:
            return False
        del self._memory_index[chunk_id]
        self.vector_store.delete(chunk_id)
        return True

    def _prune_memories(self) -> None:
        """Delete the lowest-scoring memories.

        Removes the lowest 10%, or however many exceed ``max_memories`` if
        that is more. Previously a store of fewer than 10 entries was never
        pruned because ``len // 10`` is 0.
        """
        scored_memories = sorted(
            (
                (chunk_id, entry.calculate_current_score())
                for chunk_id, entry in self._memory_index.items()
            ),
            key=lambda item: item[1]
        )
        total = len(scored_memories)
        delete_count = max(total // 10, total - self.max_memories)
        for chunk_id, _ in scored_memories[:delete_count]:
            self.delete_memory(chunk_id)

    def start_decay_process(self) -> None:
        """Start the background decay loop; a no-op if it is already running.

        The loop applies decay every ``decay_interval_hours`` until
        ``stop_decay_process`` is called.
        """
        if self.decay_thread is not None and self.decay_thread.is_alive():
            return
        # A previous stop leaves the event set; clear it so the new loop runs.
        self._stop_decay.clear()

        def decay_loop():
            # wait() returns True as soon as the stop event is set.
            while not self._stop_decay.wait(self.decay_interval_hours * 3600):
                self._apply_decay()

        self.decay_thread = threading.Thread(target=decay_loop, daemon=True)
        self.decay_thread.start()

    def stop_decay_process(self) -> None:
        """Stop the background decay loop and wait for its thread to finish."""
        self._stop_decay.set()
        if self.decay_thread:
            self.decay_thread.join()
            self.decay_thread = None
# NOTE(review): presumably a method of LongTermMemoryManager (its decay loop
# calls self._apply_decay); class-level indentation was lost in this extract.
def _apply_decay(self) -> None:
    """Apply one decay step to every memory, then delete those below the threshold."""
    for entry in self._memory_index.values():
        # Decay the base score.
        # NOTE(review): calculate_current_score() also scales importance_score
        # by decay_base ** days_elapsed, so decay compounds with this step —
        # confirm that is intended.
        entry.importance_score *= entry.decay_base
    # Delete memories that fell below the threshold. The ids are collected
    # first so the dict is not modified while it is being iterated.
    to_delete = [
        chunk_id for chunk_id, entry in self._memory_index.items()
        if entry.calculate_current_score() < self.deletion_threshold
    ]
    for chunk_id in to_delete:
        self.delete_memory(chunk_id)

本节为你提供的核心技术价值:掌握实体-关系图谱的设计、图数据库的集成、以及项目文档的解析与结构化方法,构建 AI IDE 的持久知识库。
永久记忆(Permanent Memory)是 AI IDE 的"知识库"与"档案馆"。与长期记忆的向量形式不同,永久记忆以结构化的形式存储项目知识:
永久记忆的核心特征是结构化与可推理。通过图遍历,AI 可以理解代码的整体架构而非孤立的片段。
知识图谱使用属性图模型(Property Graph),每个节点和边都可以拥有属性。

# permanent_memory/knowledge_graph.py
from typing import List, Optional, Dict, Any, Set
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import hashlib
class EntityType(Enum):
    """Kinds of node that can be stored in the knowledge graph."""

    FILE = "file"
    DIRECTORY = "directory"
    CLASS = "class"
    FUNCTION = "function"
    METHOD = "method"
    VARIABLE = "variable"
    CONSTANT = "constant"
    INTERFACE = "interface"
    MODULE = "module"
    PACKAGE = "package"
    TYPE = "type"
    DOCUMENT = "document"
    DECISION = "decision"  # architecture decision
class RelationType(Enum):
    """Kinds of edge that can connect two entities in the knowledge graph."""

    DEFINES = "defines"              # a file defines a class/function
    CONTAINS = "contains"            # a container holds a member
    CALLS = "calls"                  # function call
    REFERENCES = "references"        # variable reference
    INHERITS = "inherits"            # class inheritance
    IMPLEMENTS = "implements"        # class implements an interface
    IMPORTS = "imports"              # import relationship
    DEPENDS_ON = "depends_on"        # dependency
    OVERRIDES = "overrides"          # method override
    DECORATED_BY = "decorated_by"    # decorator
    DOCUMENTED_BY = "documented_by"  # link to documentation
    RELATED_TO = "related_to"        # generic association
@dataclass
class Entity:
    """A node of the knowledge graph.

    When ``entity_id`` is left empty, ``__post_init__`` derives one from the
    entity type and ``qualified_name``.
    """

    entity_id: str = ""
    entity_type: EntityType = EntityType.FILE
    name: str = ""
    # Fully qualified name, e.g. "myapp.services.UserService".
    qualified_name: str = ""

    # Source location.
    file_path: Optional[str] = None
    start_line: int = 0
    end_line: int = 0

    # Free-form attributes.
    properties: Dict[str, Any] = field(default_factory=dict)

    # Documentation attached to the entity.
    docstring: str = ""
    comments: List[str] = field(default_factory=list)

    # Signature (used for functions/classes).
    signature: str = ""

    # Metadata.
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    source_language: str = ""

    # Importance, used for search ranking.
    importance_score: float = 0.5

    # Labels.
    tags: Set[str] = field(default_factory=set)

    def __post_init__(self):
        # Keep a caller-supplied id; otherwise derive a deterministic one.
        self.entity_id = self.entity_id or self._generate_id()

    def _generate_id(self) -> str:
        """Return the first 16 hex characters of md5("<type value>:<qualified_name>")."""
        seed = f"{self.entity_type.value}:{self.qualified_name}"
        digest = hashlib.md5(seed.encode())
        return digest.hexdigest()[:16]

    def add_property(self, key: str, value: Any) -> None:
        """Set ``properties[key]`` and refresh ``updated_at``."""
        self.properties[key] = value
        self.updated_at = datetime.now()

    def add_tag(self, tag: str) -> None:
        """Add ``tag`` to the tag set and refresh ``updated_at``."""
        self.tags.add(tag)
        self.updated_at = datetime.now()
@dataclass
class Relation:
    """A directed edge of the knowledge graph.

    When ``relation_id`` is left empty, ``__post_init__`` derives one from the
    source id, the relation type and the target id.
    """

    relation_id: str = ""
    source_id: str = ""  # id of the source entity
    target_id: str = ""  # id of the target entity
    relation_type: RelationType = RelationType.DEFINES

    # Free-form attributes.
    properties: Dict[str, Any] = field(default_factory=dict)

    # Weight (intended for path search).
    weight: float = 1.0

    # Code context in which the relation occurs.
    context: str = ""
    line_number: int = 0
    created_at: datetime = field(default_factory=datetime.now)

    def __post_init__(self):
        # Keep a caller-supplied id; otherwise derive a deterministic one.
        self.relation_id = self.relation_id or self._generate_id()

    def _generate_id(self) -> str:
        """Return the first 16 hex characters of md5("<source>:<type value>:<target>")."""
        seed = f"{self.source_id}:{self.relation_type.value}:{self.target_id}"
        digest = hashlib.md5(seed.encode())
        return digest.hexdigest()[:16]
class KnowledgeGraph:
    """In-memory property graph of entities (nodes) and relations (edges).

    Entities and relations are stored in dictionaries keyed by their ids, with
    secondary indexes by entity name, entity type, relation source and relation
    target. When a ``graph_db`` backend is supplied, every added entity and
    relation is also forwarded to it through ``create_node`` / ``create_edge``.
    """

    def __init__(self, graph_db: Optional[Any] = None):
        # Primary in-memory storage.
        self._entities: Dict[str, Entity] = {}
        self._relations: Dict[str, Relation] = {}
        # Secondary indexes (values are lists of ids).
        self._entity_by_name: Dict[str, List[str]] = {}
        self._entity_by_type: Dict[EntityType, List[str]] = {}
        self._relations_by_source: Dict[str, List[str]] = {}
        self._relations_by_target: Dict[str, List[str]] = {}
        # Optional graph database backend.
        self.graph_db = graph_db

    @staticmethod
    def _unindex(index: Dict[Any, List[str]], key: Any, item_id: str) -> None:
        """Remove ``item_id`` from ``index[key]``, dropping the bucket if it empties."""
        bucket = index.get(key)
        if bucket and item_id in bucket:
            bucket.remove(item_id)
            if not bucket:
                del index[key]

    def add_entity(self, entity: Entity) -> str:
        """Store ``entity`` and index it by name and type; return its id.

        Adding an id that already exists replaces the stored entity. The index
        entries of the previous entity are removed first so lookups by name or
        type never return the same entity twice.
        """
        previous = self._entities.get(entity.entity_id)
        if previous is not None:
            self._unindex(self._entity_by_name, previous.name, previous.entity_id)
            self._unindex(self._entity_by_type, previous.entity_type, previous.entity_id)

        self._entities[entity.entity_id] = entity
        self._entity_by_name.setdefault(entity.name, []).append(entity.entity_id)
        self._entity_by_type.setdefault(entity.entity_type, []).append(entity.entity_id)

        # Mirror to the graph database when one is configured.
        if self.graph_db:
            self.graph_db.create_node(entity)
        return entity.entity_id

    def add_relation(self, relation: Relation) -> str:
        """Store ``relation`` and index it by source and target; return its id.

        Raises:
            ValueError: if the source or target entity is not in the graph.

        Adding an id that already exists replaces the stored relation and its
        index entries instead of duplicating them.
        """
        if relation.source_id not in self._entities:
            raise ValueError(f"Source entity not found: {relation.source_id}")
        if relation.target_id not in self._entities:
            raise ValueError(f"Target entity not found: {relation.target_id}")

        previous = self._relations.get(relation.relation_id)
        if previous is not None:
            self._unindex(self._relations_by_source, previous.source_id, previous.relation_id)
            self._unindex(self._relations_by_target, previous.target_id, previous.relation_id)

        self._relations[relation.relation_id] = relation
        self._relations_by_source.setdefault(relation.source_id, []).append(relation.relation_id)
        self._relations_by_target.setdefault(relation.target_id, []).append(relation.relation_id)

        # Mirror to the graph database when one is configured.
        if self.graph_db:
            self.graph_db.create_edge(relation)
        return relation.relation_id

    def get_entity(self, entity_id: str) -> Optional[Entity]:
        """Return the entity with ``entity_id``, or ``None`` if unknown."""
        return self._entities.get(entity_id)

    def get_entities_by_type(self, entity_type: EntityType) -> List[Entity]:
        """Return all entities indexed under ``entity_type``."""
        entity_ids = self._entity_by_type.get(entity_type, [])
        return [self._entities[eid] for eid in entity_ids if eid in self._entities]

    def get_entities_by_name(self, name: str) -> List[Entity]:
        """Return all entities whose ``name`` equals ``name`` exactly."""
        entity_ids = self._entity_by_name.get(name, [])
        return [self._entities[eid] for eid in entity_ids if eid in self._entities]

    def get_outgoing_relations(self, entity_id: str) -> List[Relation]:
        """Return the relations whose source is ``entity_id``."""
        relation_ids = self._relations_by_source.get(entity_id, [])
        return [self._relations[rid] for rid in relation_ids if rid in self._relations]

    def get_incoming_relations(self, entity_id: str) -> List[Relation]:
        """Return the relations whose target is ``entity_id``."""
        relation_ids = self._relations_by_target.get(entity_id, [])
        return [self._relations[rid] for rid in relation_ids if rid in self._relations]

    def _neighbor_relations(self, entity_id: str, direction: str) -> List[Relation]:
        """Return the relations to follow from ``entity_id`` for ``direction``."""
        if direction == "outgoing":
            return self.get_outgoing_relations(entity_id)
        if direction == "incoming":
            return self.get_incoming_relations(entity_id)
        # Any other value means "both directions".
        return self.get_outgoing_relations(entity_id) + self.get_incoming_relations(entity_id)

    def traverse(
        self,
        start_id: str,
        relation_types: Optional[List[RelationType]] = None,
        max_depth: int = 3,
        direction: str = "outgoing"
    ) -> List[Entity]:
        """Breadth-first traversal from ``start_id``.

        Args:
            start_id: id of the entity to start from (never part of the result).
            relation_types: when given and non-empty, only relations of these
                types are followed.
            max_depth: maximum number of hops from the start entity.
            direction: ``"outgoing"``, ``"incoming"``; any other value follows
                relations in both directions.

        Returns:
            The reachable entities in breadth-first discovery order, each at
            most once.
        """
        visited = {start_id}
        reached: List[Entity] = []
        frontier = [start_id]

        # One loop iteration per hop; level lists avoid O(n) list.pop(0).
        for _ in range(max_depth):
            next_frontier: List[str] = []
            for current_id in frontier:
                for rel in self._neighbor_relations(current_id, direction):
                    if relation_types and rel.relation_type not in relation_types:
                        continue
                    # The neighbour is the *other* endpoint of the edge. This
                    # matters when both directions are followed: an outgoing
                    # edge leads to its target, an incoming one to its source.
                    if rel.source_id == current_id:
                        next_id = rel.target_id
                    else:
                        next_id = rel.source_id
                    if next_id in visited:
                        continue
                    visited.add(next_id)
                    next_frontier.append(next_id)
                    entity = self._entities.get(next_id)
                    if entity is not None:
                        reached.append(entity)
            if not next_frontier:
                break
            frontier = next_frontier
        return reached

    def find_path(
        self,
        source_id: str,
        target_id: str,
        max_length: int = 5
    ) -> Optional[List[Entity]]:
        """Find a shortest path along outgoing relations.

        Args:
            source_id: id of the first entity of the path.
            target_id: id of the last entity of the path.
            max_length: maximum number of relations the path may use.

        Returns:
            The entities along the path, source and target included, or
            ``None`` when no path within ``max_length`` exists or when
            ``source_id == target_id`` refers to an unknown entity.
        """
        if source_id == target_id:
            entity = self._entities.get(source_id)
            return [entity] if entity is not None else None

        # parents maps each discovered id to the id it was reached from.
        parents: Dict[str, Optional[str]] = {source_id: None}
        frontier = [source_id]
        for _ in range(max_length):
            next_frontier: List[str] = []
            for current_id in frontier:
                for rel in self.get_outgoing_relations(current_id):
                    next_id = rel.target_id
                    if next_id in parents:
                        continue
                    parents[next_id] = current_id
                    if next_id == target_id:
                        return self._build_path(parents, target_id)
                    next_frontier.append(next_id)
            if not next_frontier:
                break
            frontier = next_frontier
        return None

    def _build_path(self, parents: Dict[str, Optional[str]], end_id: str) -> List[Entity]:
        """Walk ``parents`` back from ``end_id`` and return the entities in order."""
        ids: List[str] = []
        cursor: Optional[str] = end_id
        while cursor is not None:
            ids.append(cursor)
            cursor = parents[cursor]
        ids.reverse()
        return [self._entities[pid] for pid in ids]

    def get_code_context(self, entity_id: str) -> str:
        """Return a short textual description of where the entity lives.

        Returns an empty string when the entity is unknown or has no
        ``file_path``. The file itself is not read; the result is a two-line
        placeholder built from the entity's type, qualified name and line range.
        """
        entity = self._entities.get(entity_id)
        if not entity or not entity.file_path:
            return ""
        return f"// {entity.entity_type.value}: {entity.qualified_name}\n// Lines {entity.start_line}-{entity.end_line}"

# Article prose that followed this code on the same source line (kept verbatim):
# 架构决策是永久记忆的重要组成部分。ADR 记录了项目中的关键技术决策,包括背景、决策、结果和替代方案。
# permanent_memory/adr.py
from typing import List, Optional, Dict, Any
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
class DecisionStatus(Enum):
    """Lifecycle state of an architecture decision record."""

    PROPOSED = "proposed"
    ACCEPTED = "accepted"
    DEPRECATED = "deprecated"
    SUPERSEDED = "superseded"
class DecisionImpact(Enum):
    """Impact level of an architecture decision."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
@dataclass
class ArchitectureDecision:
    """An architecture decision record (ADR)."""

    adr_id: str = ""
    title: str = ""

    # Background and problem statement.
    context: str = ""

    # The decision itself and who made it.
    decision: str = ""
    decision_maker: str = ""
    decision_date: datetime = field(default_factory=datetime.now)

    # Lifecycle state; superseded_by holds the id of the replacing ADR.
    status: DecisionStatus = DecisionStatus.PROPOSED
    superseded_by: Optional[str] = None

    # Impact assessment.
    impact: DecisionImpact = DecisionImpact.MEDIUM
    affected_components: List[str] = field(default_factory=list)

    # Positive and negative consequences.
    consequences: List[str] = field(default_factory=list)

    # Alternatives considered; each item holds
    # {"name": "", "description": "", "reason_rejected": ""}.
    alternatives: List[Dict[str, str]] = field(default_factory=list)

    # Related material.
    related_adr_ids: List[str] = field(default_factory=list)
    documentation_links: List[str] = field(default_factory=list)

    # Labels.
    tags: List[str] = field(default_factory=list)

    # Metadata.
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    author: str = ""

    def mark_accepted(self) -> None:
        """Switch the status to ACCEPTED and refresh ``updated_at``."""
        self.status = DecisionStatus.ACCEPTED
        self.updated_at = datetime.now()

    def mark_deprecated(self, reason: str) -> None:
        """Switch the status to DEPRECATED and record ``reason`` as a consequence."""
        note = f"Deprecated: {reason}"
        self.status = DecisionStatus.DEPRECATED
        self.consequences.append(note)
        self.updated_at = datetime.now()

    def mark_superseded(self, new_adr_id: str) -> None:
        """Switch the status to SUPERSEDED and remember the replacing ADR id."""
        self.status = DecisionStatus.SUPERSEDED
        self.superseded_by = new_adr_id
        self.updated_at = datetime.now()
class ADRManager:
    """Creates, indexes, searches and exports architecture decision records."""

    def __init__(self, knowledge_graph: KnowledgeGraph):
        self.knowledge_graph = knowledge_graph
        self._adrs: Dict[str, ArchitectureDecision] = {}
        self._adrs_by_component: Dict[str, List[str]] = {}

    def _next_adr_id(self) -> str:
        """Return an unused, zero-padded sequential id such as ``"0001"``."""
        sequence = len(self._adrs) + 1
        while f"{sequence:04d}" in self._adrs:
            sequence += 1
        return f"{sequence:04d}"

    def create_adr(
        self,
        title: str,
        context: str,
        decision: str,
        decision_maker: str,
        impact: DecisionImpact = DecisionImpact.MEDIUM,
        tags: Optional[List[str]] = None,
        affected_components: Optional[List[str]] = None
    ) -> ArchitectureDecision:
        """Create an ADR, register it, and mirror it into the knowledge graph.

        Args:
            title: ADR title; also used as the graph entity name.
            context: background and problem statement.
            decision: the decision text.
            decision_maker: stored both as decision maker and as author.
            impact: impact level recorded on the ADR.
            tags: optional labels (copied, the caller's list is not kept).
            affected_components: optional component names the ADR is indexed
                under for ``get_adr_by_component``.

        Returns:
            The newly created record, carrying a unique ``adr_id``.
        """
        adr = ArchitectureDecision(
            # Every record needs its own id; without one all ADRs would be
            # stored under the same empty key and overwrite each other.
            adr_id=self._next_adr_id(),
            title=title,
            context=context,
            decision=decision,
            decision_maker=decision_maker,
            author=decision_maker,
            impact=impact,
            tags=list(tags) if tags else [],
            affected_components=list(affected_components) if affected_components else []
        )
        self._adrs[adr.adr_id] = adr

        # Index by affected component.
        for component in adr.affected_components:
            self._adrs_by_component.setdefault(component, []).append(adr.adr_id)

        # Mirror the decision into the knowledge graph.
        entity = Entity(
            entity_type=EntityType.DECISION,
            name=title,
            qualified_name=f"ADR-{adr.adr_id}",
            properties={
                "status": adr.status.value,
                "impact": adr.impact.value,
                "decision_date": adr.decision_date.isoformat()
            },
            docstring=f"{context}\n\nDecision: {decision}",
            tags=set(adr.tags)
        )
        self.knowledge_graph.add_entity(entity)
        return adr

    def get_adr(self, adr_id: str) -> Optional[ArchitectureDecision]:
        """Return the ADR with ``adr_id``, or ``None`` if unknown."""
        return self._adrs.get(adr_id)

    def get_adr_by_component(self, component: str) -> List[ArchitectureDecision]:
        """Return the ADRs indexed under ``component``."""
        adr_ids = self._adrs_by_component.get(component, [])
        return [self._adrs[aid] for aid in adr_ids if aid in self._adrs]

    def get_active_adrs(self) -> List[ArchitectureDecision]:
        """Return the ADRs whose status is ACCEPTED or PROPOSED."""
        active = (DecisionStatus.ACCEPTED, DecisionStatus.PROPOSED)
        return [adr for adr in self._adrs.values() if adr.status in active]

    def search_adrs(self, query: str) -> List[ArchitectureDecision]:
        """Case-insensitive substring search over title, context, decision and tags."""
        needle = query.lower()
        return [
            adr for adr in self._adrs.values()
            if needle in adr.title.lower()
            or needle in adr.context.lower()
            or needle in adr.decision.lower()
            or any(needle in tag.lower() for tag in adr.tags)
        ]

    def export_markdown(self, adr: ArchitectureDecision) -> str:
        """Render ``adr`` as a Markdown document.

        The "Alternatives Considered" and "Related ADRs" sections are emitted
        only when the record has entries for them. Keys missing from an
        alternative are rendered as empty strings.
        """
        lines = [
            f"# {adr.title}",
            "",
            f"**ADR ID:** {adr.adr_id}",
            f"**Status:** {adr.status.value}",
            f"**Date:** {adr.decision_date.strftime('%Y-%m-%d')}",
            f"**Decider:** {adr.decision_maker}",
            f"**Impact:** {adr.impact.value}",
            "",
            "## Context",
            adr.context,
            "",
            "## Decision",
            adr.decision,
            "",
            "## Consequences",
        ]
        lines.extend(f"- {cons}" for cons in adr.consequences)

        if adr.alternatives:
            lines.append("")
            lines.append("## Alternatives Considered")
            for alt in adr.alternatives:
                lines.append(f"### {alt.get('name', '')}")
                lines.append(alt.get('description', ''))
                lines.append(f"*Rejected because: {alt.get('reason_rejected', '')}*")
                lines.append("")

        if adr.related_adr_ids:
            lines.append("")
            lines.append("## Related ADRs")
            lines.extend(f"- {related_id}" for related_id in adr.related_adr_ids)

        return "\n".join(lines)

# Article prose that followed this code on the same source line (kept verbatim):
# 本节为你提供的核心技术价值:掌握向量检索、BM25 关键词匹配与图遍历三种检索方式的融合策略,实现精准、全面的记忆检索。
单一检索方式难以覆盖所有查询场景。混合检索通过组合多种检索策略,扬长避短:
| 检索方式 | 优势 | 劣势 | 最佳场景 |
|---|---|---|---|
| 向量检索 | 语义理解、相似性 | 对关键词不敏感 | “类似功能的实现”、“上次做的什么” |
| BM25 | 关键词精确匹配 | 无语义理解 | “包含 X 函数的文件”、“名为 Y 的类” |
| 图遍历 | 结构关系推理 | 需要图谱存在 | “谁调用了 X”、“X 继承自什么” |

BM25(Best Matching 25)是一种经典的关键词检索算法,比 TF-IDF 更鲁棒。
# retrieval/bm25.py
from typing import List, Dict, Any, Tuple
import math
from collections import Counter
import re
class BM25:
    """Okapi BM25 keyword ranking over a list of text documents.

    Call ``fit`` with the corpus, then ``search`` with a query. Tokens are
    lower-cased runs of ASCII letters, digits and underscores.
    """

    # Compiled once; the same tokenizer is used for documents and queries.
    _TOKEN_RE = re.compile(r'\b[a-z0-9_]+\b')

    def __init__(
        self,
        k1: float = 1.5,  # term-frequency saturation parameter
        b: float = 0.75   # document-length normalisation parameter
    ):
        self.k1 = k1
        self.b = b
        self.corpus_size = 0
        self.avgdl = 0  # average document length in tokens
        self.doc_lengths: List[int] = []
        self.doc_freqs: Dict[str, int] = {}  # term -> number of documents containing it
        self.idf: Dict[str, float] = {}
        self.doc_term_freqs: List[Dict[str, int]] = []  # per-document term counts
        self.documents: List[str] = []

    def fit(self, corpus: List[str]) -> None:
        """Build the index from ``corpus``, discarding any previous index.

        All statistics, including the IDF table, are rebuilt from scratch so a
        refit never keeps terms from an earlier corpus. The corpus is copied,
        so later changes to the caller's list do not desynchronise the index.
        """
        self.documents = list(corpus)
        self.corpus_size = len(self.documents)
        self.doc_lengths = []
        self.doc_term_freqs = []
        self.doc_freqs = Counter()

        for doc in self.documents:
            terms = self._tokenize(doc)
            self.doc_lengths.append(len(terms))
            term_freqs = Counter(terms)
            self.doc_term_freqs.append(term_freqs)
            # Each distinct term counts once per document.
            self.doc_freqs.update(term_freqs.keys())

        self.avgdl = sum(self.doc_lengths) / self.corpus_size if self.corpus_size else 0
        self._calculate_idf()

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case ``text`` and return its ASCII word tokens."""
        return self._TOKEN_RE.findall(text.lower())

    def _calculate_idf(self) -> None:
        """Recompute the IDF table: log((N - n + 0.5) / (n + 0.5) + 1)."""
        total = self.corpus_size
        self.idf = {
            term: math.log((total - freq + 0.5) / (freq + 0.5) + 1)
            for term, freq in self.doc_freqs.items()
        }

    def search(
        self,
        query: str,
        top_k: int = 10
    ) -> List[Tuple[int, float]]:
        """Rank the fitted documents against ``query``.

        Returns:
            Up to ``top_k`` ``(doc_index, score)`` pairs, best first. Documents
            scoring 0 are omitted. An empty list is returned when the query has
            no tokens or ``top_k`` is not positive.
        """
        if top_k <= 0:
            return []
        query_terms = self._tokenize(query)
        if not query_terms:
            return []

        scores = []
        for i, doc_tf in enumerate(self.doc_term_freqs):
            score = self._calculate_score(doc_tf, query_terms, self.doc_lengths[i])
            if score > 0:
                scores.append((i, score))

        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores[:top_k]

    def _calculate_score(
        self,
        doc_tf: Dict[str, int],
        query_terms: List[str],
        doc_length: int
    ) -> float:
        """Return the BM25 score of one document for the given query terms."""
        score = 0.0
        # The small epsilon keeps the division defined when avgdl is 0.
        doc_len_norm = doc_length / (self.avgdl + 1e-8)
        length_penalty = self.k1 * (1 - self.b + self.b * doc_len_norm)

        for term in query_terms:
            tf = doc_tf.get(term)
            if not tf:
                continue
            idf = self.idf.get(term, 0)
            score += idf * ((tf * (self.k1 + 1)) / (tf + length_penalty))
        return score
class BM25Indexer:
    """Keeps document dictionaries next to a BM25 index built from their text."""

    def __init__(self):
        self.bm25 = BM25()
        self.documents: List[Dict[str, Any]] = []  # original document dicts
        # Name of the dict key holding the indexed text (set by add_documents).
        self._content_field = "content"

    def add_documents(
        self,
        documents: List[Dict[str, str]],
        id_field: str = "id",
        content_field: str = "content"
    ) -> None:
        """Index ``documents``, replacing whatever was indexed before.

        Args:
            documents: document dictionaries to index.
            id_field: accepted for interface compatibility; not used here.
            content_field: key whose value is the text to index. It is
                remembered so ``search_with_filter`` reads the same field.
        """
        self.documents = list(documents)
        self._content_field = content_field
        contents = [doc.get(content_field, "") for doc in self.documents]
        self.bm25.fit(contents)

    def search(
        self,
        query: str,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """Return the best-matching documents, each with an added ``bm25_score``."""
        hits = self.bm25.search(query, top_k)
        return [
            {**self.documents[idx], "bm25_score": score}
            for idx, score in hits
        ]

    def search_with_filter(
        self,
        query: str,
        filter_fn,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """Search only the documents for which ``filter_fn(doc)`` is truthy.

        A temporary index is built over the filtered subset, so scores are
        relative to that subset rather than to the full corpus. The temporary
        index reuses this indexer's ``k1``/``b`` parameters and content field.
        """
        filtered_docs = [doc for doc in self.documents if filter_fn(doc)]
        if not filtered_docs:
            return []

        temp_bm25 = BM25(k1=self.bm25.k1, b=self.bm25.b)
        temp_bm25.fit([doc.get(self._content_field, "") for doc in filtered_docs])
        hits = temp_bm25.search(query, top_k)
        return [
            {**filtered_docs[idx], "bm25_score": score}
            for idx, score in hits
        ]

# retrieval/hybrid_search.py
from typing import List, Optional, Dict, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
import numpy as np
class SearchStrategy(Enum):
    """Retrieval strategies understood by the hybrid search engine."""

    VECTOR_ONLY = "vector_only"
    BM25_ONLY = "bm25_only"
    GRAPH_ONLY = "graph_only"
    HYBRID = "hybrid"
    ADAPTIVE = "adaptive"  # pick a strategy automatically from the query
@dataclass
class SearchResult:
    """One retrieved item together with its per-source scores."""

    chunk_id: str
    content: str
    score: float
    # Which retriever produced the item: "vector", "bm25" or "graph".
    source_type: str = ""

    # Arbitrary metadata attached to the item.
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Position of the item in the original result list.
    source_index: Optional[int] = None

    # Scores contributed by the individual retrievers (None when absent).
    vector_score: Optional[float] = None
    bm25_score: Optional[float] = None
    graph_score: Optional[float] = None
@dataclass
class HybridSearchConfig:
    """Tunable parameters of the hybrid search engine."""

    # Weights used when the three scores are combined.
    vector_weight: float = 0.4
    bm25_weight: float = 0.3
    graph_weight: float = 0.3

    # Vector retrieval.
    vector_top_k: int = 50
    vector_score_threshold: float = 0.5

    # BM25 retrieval.
    bm25_top_k: int = 50
    bm25_score_threshold: float = 0.1

    # Graph retrieval.
    graph_max_depth: int = 3
    graph_score_threshold: float = 0.3

    # Re-ranking.
    use_reranker: bool = True
    reranker_top_k: int = 10

    # Score normalisation method: "min_max", "z_score" or "rank".
    score_normalization: str = "min_max"
class HybridSearchEngine:
    """Combines vector, BM25 and knowledge-graph lookups into one ranked list."""

    def __init__(
        self,
        vector_store: Any,  # VectorStore
        bm25_indexer: BM25Indexer,
        knowledge_graph: Any,  # KnowledgeGraph
        config: Optional[HybridSearchConfig] = None
    ):
        self.vector_store = vector_store
        self.bm25_indexer = bm25_indexer
        self.knowledge_graph = knowledge_graph
        self.config = config or HybridSearchConfig()

    def search(
        self,
        query: str,
        query_vector: List[float],
        top_k: int = 10,
        strategy: SearchStrategy = SearchStrategy.HYBRID,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """Run a search with the requested strategy.

        Args:
            query: the textual query (used by BM25, graph and strategy selection).
            query_vector: embedding of the query (used by vector search).
            top_k: maximum number of results to return.
            strategy: which retriever(s) to use; ADAPTIVE picks one from the
                query text via ``_select_strategy``.
            filters: optional key/value filters passed to the retrievers.

        Returns:
            Up to ``top_k`` results, best first.
        """
        if strategy == SearchStrategy.ADAPTIVE:
            strategy = self._select_strategy(query)

        if strategy == SearchStrategy.VECTOR_ONLY:
            return self._vector_search(query_vector, top_k, filters)
        if strategy == SearchStrategy.BM25_ONLY:
            return self._bm25_search(query, top_k, filters)
        if strategy == SearchStrategy.GRAPH_ONLY:
            return self._graph_search(query, top_k, filters)
        return self._hybrid_search(query, query_vector, top_k, filters)

    def _vector_search(
        self,
        query_vector: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]]
    ) -> List[SearchResult]:
        """Vector-only search.

        The store is asked for at least ``config.vector_top_k`` candidates, and
        for ``top_k`` when that is larger, so a big ``top_k`` is not silently
        capped. Scores are normalised and results below
        ``config.vector_score_threshold`` are dropped. Each hit returned by the
        store is read through its ``score``, ``chunk_id``, ``content`` and
        ``metadata`` attributes.
        """
        vector_results = self.vector_store.search(
            query_embedding=query_vector,
            top_k=max(top_k, self.config.vector_top_k),
            filters=filters
        )
        normalized = self._normalize_scores([r.score for r in vector_results])

        results = []
        for r, norm_score in zip(vector_results, normalized):
            if norm_score >= self.config.vector_score_threshold:
                results.append(SearchResult(
                    chunk_id=r.chunk_id,
                    content=r.content,
                    score=norm_score,
                    source_type="vector",
                    metadata=r.metadata,
                    vector_score=r.score
                ))
        return results[:top_k]

    def _bm25_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]]
    ) -> List[SearchResult]:
        """BM25-only search.

        The indexer is asked for at least ``config.bm25_top_k`` candidates.
        ``filters`` keeps only documents whose fields equal every given value.
        Scores are normalised and results below ``config.bm25_score_threshold``
        are dropped.
        """
        bm25_results = self.bm25_indexer.search(
            query=query,
            top_k=max(top_k, self.config.bm25_top_k)
        )
        if filters:
            bm25_results = [
                r for r in bm25_results
                if all(r.get(k) == v for k, v in filters.items())
            ]
        normalized = self._normalize_scores([r.get("bm25_score", 0) for r in bm25_results])

        results = []
        for r, norm_score in zip(bm25_results, normalized):
            if norm_score >= self.config.bm25_score_threshold:
                results.append(SearchResult(
                    chunk_id=r.get("chunk_id", ""),
                    content=r.get("content", ""),
                    score=norm_score,
                    source_type="bm25",
                    metadata=r,
                    bm25_score=r.get("bm25_score")
                ))
        return results[:top_k]

    def _graph_search(
        self,
        query: str,
        top_k: int,
        filters: Optional[Dict[str, Any]]
    ) -> List[SearchResult]:
        """Graph search (simplified).

        Looks up entities whose name equals the whole query string and scores
        them by their ``importance_score``. ``filters`` is not applied here. A
        full implementation would extract entities from the query, locate them
        in the graph and traverse from there.
        """
        entities = self.knowledge_graph.get_entities_by_name(query)
        return [
            SearchResult(
                chunk_id=entity.entity_id,
                content=entity.docstring or entity.name,
                score=entity.importance_score,
                source_type="graph",
                metadata={
                    "entity_type": entity.entity_type.value,
                    "qualified_name": entity.qualified_name
                },
                graph_score=entity.importance_score
            )
            for entity in entities[:top_k]
        ]

    def _hybrid_search(
        self,
        query: str,
        query_vector: List[float],
        top_k: int,
        filters: Optional[Dict[str, Any]]
    ) -> List[SearchResult]:
        """Run all three retrievers, merge by ``chunk_id`` and rank by weighted score."""
        # The three searches run one after another (not in parallel).
        vector_results = self._vector_search(query_vector, self.config.vector_top_k, filters)
        bm25_results = self._bm25_search(query, self.config.bm25_top_k, filters)
        graph_results = self._graph_search(query, self.config.graph_max_depth * 10, filters)

        # Merge: the first retriever to report a chunk owns the result object,
        # later retrievers only contribute their score and metadata.
        all_results: Dict[str, SearchResult] = {}
        for r in vector_results:
            r.vector_score = r.score
            all_results[r.chunk_id] = r
        for r in bm25_results:
            existing = all_results.get(r.chunk_id)
            if existing is not None:
                existing.bm25_score = r.score
                existing.metadata.update(r.metadata)
            else:
                r.bm25_score = r.score
                all_results[r.chunk_id] = r
        for r in graph_results:
            existing = all_results.get(r.chunk_id)
            if existing is not None:
                existing.graph_score = r.score
                existing.metadata.update(r.metadata)
            else:
                r.graph_score = r.score
                all_results[r.chunk_id] = r

        # Weighted sum; a retriever that did not report the chunk contributes 0.
        for r in all_results.values():
            v_score = r.vector_score if r.vector_score is not None else 0
            b_score = r.bm25_score if r.bm25_score is not None else 0
            g_score = r.graph_score if r.graph_score is not None else 0
            r.score = (
                self.config.vector_weight * v_score +
                self.config.bm25_weight * b_score +
                self.config.graph_weight * g_score
            )

        sorted_results = sorted(all_results.values(), key=lambda x: x.score, reverse=True)

        # Optional re-ranking of the head of the list. The remaining results
        # are kept behind the re-ranked head so that a top_k larger than
        # reranker_top_k can still be satisfied.
        if self.config.use_reranker and len(sorted_results) > self.config.reranker_top_k:
            head = self._rerank(query, sorted_results)
            in_head = {id(r) for r in head}
            tail = [r for r in sorted_results if id(r) not in in_head]
            sorted_results = head + tail
        return sorted_results[:top_k]

    def _normalize_scores(self, scores: List[float]) -> List[float]:
        """Normalise ``scores`` with the configured method.

        * ``"min_max"``: scale to [0, 1]; all-equal input yields 0.5 each.
        * ``"z_score"``: subtract the mean and divide by the population
          standard deviation; all-equal input yields 0.5 each.
        * anything else: rank-based, ``1 - rank / len(scores)``.
        """
        if not scores:
            return []
        method = self.config.score_normalization
        if method == "min_max":
            min_s = min(scores)
            max_s = max(scores)
            if max_s - min_s < 1e-8:
                return [0.5] * len(scores)
            return [(s - min_s) / (max_s - min_s) for s in scores]
        if method == "z_score":
            mean = sum(scores) / len(scores)
            std = (sum((s - mean) ** 2 for s in scores) / len(scores)) ** 0.5
            if std < 1e-8:
                return [0.5] * len(scores)
            return [(s - mean) / std for s in scores]
        # rank
        sorted_scores = sorted(scores, reverse=True)
        rank_map = {s: i for i, s in enumerate(sorted_scores)}
        return [1 - rank_map[s] / len(scores) for s in scores]

    def _rerank(
        self,
        query: str,
        results: List[SearchResult]
    ) -> List[SearchResult]:
        """Re-rank the head of ``results`` (simplified placeholder).

        A real implementation would use a cross-encoder or another re-ranking
        model. Here the first ``config.reranker_top_k`` results are returned in
        their existing order.
        """
        return results[:self.config.reranker_top_k]

    def _select_strategy(self, query: str) -> SearchStrategy:
        """Pick a strategy from keywords found in the lower-cased query.

        Checks are substring matches and run in order: keyword-style queries
        map to BM25, relationship queries to graph search, similarity queries
        to vector search; everything else falls back to HYBRID.
        """
        query_lower = query.lower()
        # Strong keyword features -> BM25.
        if any(kw in query_lower for kw in ["file", "function", "class", "name", "called"]):
            return SearchStrategy.BM25_ONLY
        # Relationship queries -> graph search.
        if any(kw in query_lower for kw in ["who", "which", "where", "call", "inherit", "use"]):
            return SearchStrategy.GRAPH_ONLY
        # Semantic similarity -> vector search.
        if any(kw in query_lower for kw in ["similar", "like", "类似", "类似的"]):
            return SearchStrategy.VECTOR_ONLY
        return SearchStrategy.HYBRID

# Article prose that followed this code on the same source line (kept verbatim):
# 本节为你提供的核心技术价值:掌握基于重要性评分的选择性记忆保留策略,以及使用 LLM 生成压缩摘要的技术实现。
AI IDE 的资源是有限的,但记忆的价值是不同的。记忆压缩解决的核心问题:
记忆压缩的核心思想是:只保留最重要的信息,用压缩形式存储。
# compression/importance_scorer.py
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import numpy as np
class ValueDimension(Enum):
    """Dimensions along which the value of a memory is scored."""

    NOVELTY = "novelty"
    ACCURACY = "accuracy"
    RELEVANCE = "relevance"
    COMPLETENESS = "completeness"
    ACTIONABILITY = "actionability"
@dataclass
class ScoreBreakdown:
    """Per-dimension importance scores plus the combined total."""

    novelty: float = 0.0
    accuracy: float = 0.0
    relevance: float = 0.0
    completeness: float = 0.0
    actionability: float = 0.0
    total: float = 0.0

    def weighted_sum(self, weights: Dict[ValueDimension, float]) -> float:
        """Combine the five dimension scores using ``weights``.

        A dimension missing from ``weights`` falls back to its built-in
        default (0.2, 0.2, 0.3, 0.15, 0.15 in the order listed below).
        """
        parts = (
            (ValueDimension.NOVELTY, 0.2, self.novelty),
            (ValueDimension.ACCURACY, 0.2, self.accuracy),
            (ValueDimension.RELEVANCE, 0.3, self.relevance),
            (ValueDimension.COMPLETENESS, 0.15, self.completeness),
            (ValueDimension.ACTIONABILITY, 0.15, self.actionability),
        )
        dimension, fallback, value = parts[0]
        combined = weights.get(dimension, fallback) * value
        # Accumulate left to right, in the listed order.
        for dimension, fallback, value in parts[1:]:
            combined = combined + weights.get(dimension, fallback) * value
        return combined
class ImportanceScorer:
    """Heuristic scorer that rates how important a memory is.

    Five dimensions (novelty, accuracy, relevance, completeness,
    actionability) are scored with keyword-based heuristics and combined with
    the weights held in ``self.weights``.
    """

    def __init__(
        self,
        llm_client: Optional[Any] = None,
        use_llm_judgment: bool = False
    ):
        # Stored for callers; the heuristics in this class do not consult them.
        self.llm_client = llm_client
        self.use_llm_judgment = use_llm_judgment
        # Weight of each dimension in the total score.
        self.weights = {
            ValueDimension.NOVELTY: 0.2,
            ValueDimension.ACCURACY: 0.2,
            ValueDimension.RELEVANCE: 0.3,
            ValueDimension.COMPLETENESS: 0.15,
            ValueDimension.ACTIONABILITY: 0.15
        }

    def score(
        self,
        memory_content: str,
        context: Dict[str, Any],
        conversation_history: Optional[List[Dict]] = None
    ) -> ScoreBreakdown:
        """Score ``memory_content`` on every dimension and compute the total.

        Args:
            memory_content: text of the memory.
            context: read for the keys ``"user_feedback"`` and
                ``"current_task"`` by the dimension scorers.
            conversation_history: earlier items (dicts with a ``"content"``
                key) used for the novelty score; ``None`` means no history.
        """
        history = conversation_history or []
        outcome = ScoreBreakdown()
        outcome.novelty = self._score_novelty(memory_content, history)
        outcome.accuracy = self._score_accuracy(context)
        outcome.relevance = self._score_relevance(memory_content, context)
        outcome.completeness = self._score_completeness(memory_content)
        outcome.actionability = self._score_actionability(memory_content, context)
        outcome.total = outcome.weighted_sum(self.weights)
        return outcome

    def _score_novelty(
        self,
        content: str,
        history: List[Dict]
    ) -> float:
        """Novelty: 1 minus the highest word-set Jaccard similarity to recent history.

        Only the last 10 history items are compared. With no history at all
        the score is 0.8.
        """
        if not history:
            return 0.8

        current_words = set(content.lower().split())
        closest = 0.0
        for earlier in history[-10:]:
            earlier_words = set(earlier.get("content", "").lower().split())
            if not current_words or not earlier_words:
                continue
            jaccard = len(current_words & earlier_words) / len(current_words | earlier_words)
            if jaccard > closest:
                closest = jaccard
        return 1.0 - closest

    def _score_accuracy(self, context: Dict[str, Any]) -> float:
        """Accuracy: 0.9 for positive feedback, 0.3 for negative, 0.6 otherwise."""
        feedback = context.get("user_feedback")
        if feedback == "positive":
            return 0.9
        if feedback == "negative":
            return 0.3
        return 0.6

    def _score_relevance(
        self,
        content: str,
        context: Dict[str, Any]
    ) -> float:
        """Relevance: share of task words found in the content, plus 0.3, capped at 1.0.

        Returns 0.5 when ``context`` has no (or an empty) ``"current_task"``.
        """
        task_words = set(context.get("current_task", "").lower().split()) if context.get("current_task", "") else set()
        if not task_words:
            return 0.5
        shared = task_words & set(content.lower().split())
        return min(1.0, len(shared) / len(task_words) + 0.3)

    def _score_completeness(self, content: str) -> float:
        """Completeness: 0.3 for very short text, else 0.5 plus 0.25 per marker kind found.

        The two marker kinds are conclusion words and context/background words.
        """
        if len(content) < 50:
            return 0.3

        lowered = content.lower()
        conclusion_markers = ["因此", "所以", "结果", "结论", "thus", "therefore", "result", "conclusion"]
        context_markers = ["因为", "由于", "背景", "because", "background"]

        completeness = 0.5
        if any(marker in lowered for marker in conclusion_markers):
            completeness += 0.25
        if any(marker in lowered for marker in context_markers):
            completeness += 0.25
        return min(1.0, completeness)

    def _score_actionability(
        self,
        content: str,
        context: Dict[str, Any]
    ) -> float:
        """Actionability: 0.5 base, +0.1 per actionable marker, -0.15 per vague marker.

        The result is clamped to [0.0, 1.0]. ``context`` is not used.
        """
        actionable_markers = [
            "def ", "class ", "function ", "method ",
            "import ", "from ", "require ",
            "config", "setting", "option",
            "fix", "bug", "issue", "problem"
        ]
        vague_markers = [
            "maybe", "perhaps", "might", "could be",
            "unclear", "unknown", "todo"
        ]
        lowered = content.lower()

        actionability = 0.5
        for marker in actionable_markers:
            if marker in lowered:
                actionability += 0.1
        for marker in vague_markers:
            if marker in lowered:
                actionability -= 0.15
        return max(0.0, min(1.0, actionability))

    def rank_memories(
        self,
        memories: List[Dict[str, Any]],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Score every memory and return copies sorted by importance, best first.

        Each returned dict is the original memory plus ``"importance_breakdown"``
        and ``"importance_score"``. A memory's own ``"context"`` entries
        override the shared ``context`` while it is scored.
        """
        ranked = []
        for memory in memories:
            merged_context = {**context, **memory.get("context", {})}
            breakdown = self.score(
                memory.get("content", ""),
                merged_context,
                memory.get("history", [])
            )
            ranked.append({
                **memory,
                "importance_breakdown": breakdown,
                "importance_score": breakdown.total
            })
        ranked.sort(key=lambda item: item["importance_score"], reverse=True)
        return ranked

# Article prose that followed this code on the same source line (kept verbatim):
# 本节为你提供的核心技术价值:构建完整的隐私安全保障体系,包括敏感信息识别、加密存储、访问控制与合规审计。

# privacy/sensitive_detector.py
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field
from enum import Enum
import re
import hashlib
class SensitiveType(Enum):
    """Categories of sensitive information the detector can report."""

    # Credentials and secrets.
    PASSWORD = "password"
    API_KEY = "api_key"
    PRIVATE_KEY = "private_key"
    TOKEN = "token"
    SECRET = "secret"
    CREDENTIAL = "credential"
    # Personally identifiable information.
    PII_NAME = "pii_name"
    PII_EMAIL = "pii_email"
    PII_PHONE = "pii_phone"
    PII_ID = "pii_id"
    PII_ADDRESS = "pii_address"
    # Other.
    CREDIT_CARD = "credit_card"
    DATABASE_URL = "database_url"
    CUSTOM = "custom"
@dataclass
class SensitiveMatch:
    """One piece of sensitive information found in a text."""

    match_type: SensitiveType
    matched_text: str
    # Character offsets of the match in the scanned text.
    start: int
    end: int
    # Detection confidence, 0.0 - 1.0.
    confidence: float
    # Replacement text to use when redacting.
    mask_value: str = ""
class SensitiveDetector:
"""敏感信息检测器"""
# 预定义的正则模式
PATTERNS = {
# 密码
SensitiveType.PASSWORD: [
r'password\s*[=:]\s*["\']?([^"\'\s,}]+)',
r'passwd\s*[=:]\s*["\']?([^"\'\s,}]+)',
r'pwd\s*[=:]\s*["\']?([^"\'\s,}]+)',
],
# API Keys
SensitiveType.API_KEY: [
r'api[_-]?key\s*[=:]\s*["\']?([a-zA-Z0-9_\-]{20,})',
r'apikey\s*[=:]\s*["\']?([a-zA-Z0-9_\-]{20,})',
r'api[_-]?secret\s*[=:]\s*["\']?([a-zA-Z0-9_\-]{20,})',
],
# Private Keys
SensitiveType.PRIVATE_KEY: [
r'-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----',
r'-----BEGIN\s+EC\s+PRIVATE\s+KEY-----',
r'-----BEGIN\s+OPENSSH\s+PRIVATE\s+KEY-----',
],
# Tokens
SensitiveType.TOKEN: [
r'bearer\s+[a-zA-Z0-9_\-\.]+',
r'token\s*[=:]\s*["\']?([a-zA-Z0-9_\-\.]{20,})',
r'access[_-]?token\s*[=:]\s*["\']?([a-zA-Z0-9_\-\.]{20,})',
],
# 数据库 URL
SensitiveType.DATABASE_URL: [
r'mongodb[+]?://[^:]+:[^@]+@',
r'mysql://[^:]+:[^@]+@',
r'postgresql://[^:]+:[^@]+@',
r'redis://[^:]+:[^@]+@',
r'sqlite:///[^?]+\?.*password=',
],
# 信用卡
SensitiveType.CREDIT_CARD: [
r'\b(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|3[47][0-9]{13})\b',
r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
],
# 邮箱
SensitiveType.PII_EMAIL: [
r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
],
# 电话号码
SensitiveType.PII_PHONE: [
r'\b1[3-9]\d{9}\b', # 中国手机号
r'\b\d{3}[-\s]?\d{3,4}[-\s]?\d{4}\b',
],
}
def __init__(self):
# 编译正则表达式
self._compiled_patterns: Dict[SensitiveType, List[re.Pattern]] = {}
for sens_type, patterns in self.PATTERNS.items():
self._compiled_patterns[sens_type] = [
re.compile(p, re.IGNORECASE) for p in patterns
]
def detect(self, text: str) -> List[SensitiveMatch]:
"""检测文本中的敏感信息"""
matches = []
for sens_type, patterns in self._compiled_patterns.items():
for pattern in patterns:
for match in pattern.finditer(text):
matched_text = match.group(0) if match.lastindex is None else match.group(1) or match.group(0)
# 计算置信度
confidence = self._calculate_confidence(sens_type, matched_text, text)
# 生成脱敏值
mask_value = self._generate_mask(sens_type, matched_text)
matches.append(SensitiveMatch(
match_type=sens_type,
matched_text=matched_text,
start=match.start(),
end=match.end(),
confidence=confidence,
mask_value=mask_value
))
# 按位置排序
matches.sort(key=lambda x: x.start)
return matches
def detect_and_redact(self, text: str) -> Tuple[str, List[SensitiveMatch]]:
"""检测并脱敏"""
matches = self.detect(text)
if not matches:
return text, []
# 从后向前替换,避免位置偏移
result = text
offset = 0
for match in matches:
start = match.start + offset
end = match.end + offset
result = result[:start] + match.mask_value + result[end:]
offset += len(match.mask_value) - (match.end - match.start)
return result, matches
def _calculate_confidence(
self,
sens_type: SensitiveType,
matched_text: str,
context: str
) -> float:
"""计算检测置信度"""
base_confidence = 0.8
# 上下文增强
context_lower = context.lower()
# 明确的变量名
explicit_names = {
SensitiveType.PASSWORD: ["password", "passwd", "pwd", "secret"],
SensitiveType.API_KEY: ["api_key", "apikey", "api_secret", "secret"],
SensitiveType.TOKEN: ["token", "access_token", "auth_token"],
}
if sens_type in explicit_names:
for name in explicit_names[sens_type]:
if name in context_lower:
return 0.95
# 格式验证
if sens_type == SensitiveType.CREDIT_CARD:
if self._validate_luhn(matched_text.replace("-", "").replace(" ", "")):
return 0.95
# 熵检测(高熵字符串更可能是密钥)
if sens_type in [SensitiveType.API_KEY, SensitiveType.TOKEN]:
entropy = self._calculate_entropy(matched_text)
if entropy > 4.0:
return base_confidence + 0.1
return base_confidence
def _generate_mask(self, sens_type: SensitiveType, matched_text: str) -> str:
    """Return the replacement text for a match of category ``sens_type``.

    Most categories map to a fixed placeholder.  E-mail addresses, phone
    numbers and card numbers get a format-preserving mask derived from
    ``matched_text``; that mask is computed only for the category that
    needs it, not for every match.

    Args:
        sens_type: Category of the match.
        matched_text: The text that was matched.

    Returns:
        The mask string, or ``"[REDACTED]"`` for an unknown category.
    """
    fixed_masks = {
        SensitiveType.PASSWORD: "[REDACTED_PASSWORD]",
        SensitiveType.API_KEY: "[REDACTED_API_KEY]",
        SensitiveType.PRIVATE_KEY: "[REDACTED_PRIVATE_KEY]",
        SensitiveType.TOKEN: "[REDACTED_TOKEN]",
        SensitiveType.SECRET: "[REDACTED_SECRET]",
        SensitiveType.CREDENTIAL: "[REDACTED_CREDENTIAL]",
        SensitiveType.PII_NAME: "[REDACTED_NAME]",
        SensitiveType.DATABASE_URL: "[REDACTED_DB_URL]",
    }
    if sens_type in fixed_masks:
        return fixed_masks[sens_type]

    # Categories whose mask keeps part of the original value.
    maskers = {
        SensitiveType.PII_EMAIL: self._mask_email,
        SensitiveType.PII_PHONE: self._mask_phone,
        SensitiveType.CREDIT_CARD: self._mask_credit_card,
    }
    masker = maskers.get(sens_type)
    if masker is None:
        return "[REDACTED]"
    return masker(matched_text)
def _mask_email(self, email: str) -> str:
"""脱敏邮箱"""
parts = email.split("@")
if len(parts) != 2:
return "[REDACTED_EMAIL]"
username = parts[0]
masked_username = username[0] + "*" * (len(username) - 2) + username[-1] if len(username) > 2 else "***"
return f"{masked_username}@{parts[1]}"
def _mask_phone(self, phone: str) -> str:
"""脱敏手机号"""
digits = re.sub(r'\D', '', phone)
if len(digits) >= 11:
return digits[:3] + "****" + digits[-4:]
return "***-****-****"
def _mask_credit_card(self, card: str) -> str:
"""脱敏信用卡号"""
digits = re.sub(r'\D', '', card)
if len(digits) >= 13:
return digits[:4] + " **** **** " + digits[-4:]
return "****-****-****-****"
@staticmethod
def _validate_luhn(card_number: str) -> bool:
"""Luhn 算法验证"""
digits = [int(d) for d in card_number if d.isdigit()]
checksum = 0
for i, digit in enumerate(reversed(digits)):
if i % 2 == 1:
digit *= 2
if digit > 9:
digit -= 9
checksum += digit
return checksum % 10 == 0
@staticmethod
def _calculate_entropy(text: str) -> float:
"""计算字符串熵"""
if not text:
return 0.0
import math
from collections import Counter
length = len(text)
counter = Counter(text)
entropy = 0.0
for count in counter.values():
p = count / length
entropy -= p * math.log2(p)
return entropy# privacy/encrypted_storage.py
from typing import Optional, Dict, Any
import hashlib
import hmac
import os
from datetime import datetime
from dataclasses import dataclass
import json
@dataclass
class EncryptedData:
    """One encrypted payload together with what is needed to decrypt it."""

    # Encrypted bytes.
    ciphertext: bytes
    # Initialisation vector used for this payload.
    iv: bytes
    # Authentication tag.
    auth_tag: bytes
    # Identifier of the key the payload was encrypted with.
    key_id: str
    # Moment of encryption.
    encrypted_at: datetime
class EncryptionKeyManager:
    """Derives and caches per-record keys from a single master key."""

    def __init__(self, master_key: Optional[bytes] = None):
        """Create the manager.

        Args:
            master_key: Master key bytes.  When omitted (or empty) a
                random 32-byte key is generated; a real deployment
                should obtain it from a key-management service.
        """
        self.master_key = master_key if master_key else os.urandom(32)
        self.key_cache: Dict[str, bytes] = {}

    def derive_key(self, purpose: str, context: str = "") -> bytes:
        """Derive a purpose-specific key.

        Returns:
            The HMAC-SHA256 digest of ``"<purpose>:<context>"`` keyed
            with the master key.
        """
        label = f"{purpose}:{context}".encode()
        mac = hmac.new(self.master_key, label, hashlib.sha256)
        return mac.digest()

    def get_data_key(self, data_id: str) -> bytes:
        """Return the data-encryption key for ``data_id``, caching it."""
        try:
            return self.key_cache[data_id]
        except KeyError:
            derived = self.derive_key("data_encryption", data_id)
            self.key_cache[data_id] = derived
            return derived

    def rotate_master_key(self, new_master_key: bytes) -> None:
        """Switch to a new master key and drop every cached derived key.

        Data encrypted under the old key has to be re-encrypted by the
        caller.
        """
        self.key_cache.clear()
        self.master_key = new_master_key
class EncryptedMemoryStorage:
    """Stores memory records as JSON, encrypted per record, in a backend.

    WARNING: ``_aes_gcm_encrypt`` and ``_aes_gcm_decrypt`` are
    placeholders.  They do not encrypt anything: the "ciphertext" is the
    plaintext followed by 16 zero bytes.  Replace them with a real
    AES-GCM implementation before storing sensitive data.
    """

    def __init__(
        self,
        key_manager: EncryptionKeyManager,
        storage_backend: Optional[Any] = None
    ):
        """Create the store.

        Args:
            key_manager: Supplies the per-record key via
                ``get_data_key(memory_id)``.
            storage_backend: Optional persistence object; this class
                calls ``save(memory_id, payload_dict)`` and
                ``load(memory_id)`` on it.
        """
        self.key_manager = key_manager
        self.storage_backend = storage_backend
        # Decrypted copies of low-sensitivity records.
        self._cache: Dict[str, Any] = {}

    def store(
        self,
        memory_id: str,
        data: Dict[str, Any],
        sensitivity_level: str = "medium"
    ) -> None:
        """Encrypt ``data`` and persist it under ``memory_id``.

        Args:
            memory_id: Record identifier; also selects the data key.
            data: JSON-serialisable record.
            sensitivity_level: Only ``"low"`` records are additionally
                kept decrypted in the in-memory cache.  Any other level
                evicts a cached copy left by an earlier ``"low"`` store
                of the same id, so ``retrieve`` cannot return stale
                data.
        """
        import base64

        iv = os.urandom(16)
        data_key = self.key_manager.get_data_key(memory_id)
        plaintext = json.dumps(data, ensure_ascii=False).encode()
        ciphertext, auth_tag = self._aes_gcm_encrypt(plaintext, data_key, iv)

        encrypted = EncryptedData(
            ciphertext=ciphertext,
            iv=iv,
            auth_tag=auth_tag,
            # Short fingerprint of the key, not the key itself.
            key_id=hashlib.sha256(data_key).hexdigest()[:16],
            encrypted_at=datetime.now()
        )

        # ``is not None`` so a backend that happens to be falsy (for
        # example an empty container-like object) is still used.
        if self.storage_backend is not None:
            self.storage_backend.save(memory_id, {
                "ciphertext": base64.b64encode(ciphertext).decode(),
                "iv": base64.b64encode(iv).decode(),
                "auth_tag": base64.b64encode(auth_tag).decode(),
                "key_id": encrypted.key_id,
                "encrypted_at": encrypted.encrypted_at.isoformat()
            })

        if sensitivity_level == "low":
            # Cache a snapshot, not the caller's dict: later mutation by
            # the caller must not change the cached record, and the
            # snapshot equals what the backend path would return.
            self._cache[memory_id] = json.loads(plaintext.decode())
        else:
            self._cache.pop(memory_id, None)

    def retrieve(self, memory_id: str) -> Optional[Dict[str, Any]]:
        """Return the decrypted record for ``memory_id``.

        Returns:
            The cached record if present; otherwise the record loaded
            from the backend and decrypted; ``None`` when there is no
            backend or the backend has no such record.
        """
        import base64

        if memory_id in self._cache:
            return self._cache[memory_id]

        if self.storage_backend is None:
            return None
        encrypted_data = self.storage_backend.load(memory_id)
        if not encrypted_data:
            return None

        ciphertext = base64.b64decode(encrypted_data["ciphertext"])
        iv = base64.b64decode(encrypted_data["iv"])
        auth_tag = base64.b64decode(encrypted_data["auth_tag"])

        # In a real deployment the key would come from a key-management
        # service.
        data_key = self.key_manager.get_data_key(memory_id)

        plaintext = self._aes_gcm_decrypt(ciphertext, data_key, iv, auth_tag)
        return json.loads(plaintext.decode())

    @staticmethod
    def _aes_gcm_encrypt(
        plaintext: bytes,
        key: bytes,
        iv: bytes
    ) -> tuple:
        """PLACEHOLDER for AES-GCM encryption; performs no encryption.

        Returns:
            ``(plaintext + 16 zero bytes, 16 zero bytes)`` standing in
            for ``(ciphertext, auth_tag)``.
        """
        # A real implementation would use the ``cryptography`` library:
        #   from cryptography.hazmat.primitives.ciphers.aead import AESGCM
        #   aesgcm = AESGCM(key)
        #   ciphertext = aesgcm.encrypt(iv, plaintext, None)
        #   # layout: ciphertext = encrypted || auth_tag (16 bytes)
        #   return ciphertext[:-16], ciphertext[-16:]
        return plaintext + b'\x00' * 16, b'\x00' * 16

    @staticmethod
    def _aes_gcm_decrypt(
        ciphertext: bytes,
        key: bytes,
        iv: bytes,
        auth_tag: bytes
    ) -> bytes:
        """PLACEHOLDER for AES-GCM decryption; nothing is authenticated.

        Returns:
            ``ciphertext`` without its last 16 bytes, i.e. the inverse
            of the placeholder encryption.
        """
        # A real implementation would use the ``cryptography`` library:
        #   from cryptography.hazmat.primitives.ciphers.aead import AESGCM
        #   aesgcm = AESGCM(key)
        #   return aesgcm.decrypt(iv, ciphertext + auth_tag, None)
        return ciphertext[:-16]
# 附录(Appendix):
以下是完整的 Memory System 实现,整合了本文讲解的所有核心组件:
# memory_system.py - 三层记忆架构完整实现
"""
AI IDE Memory System - 三层记忆架构完整实现
作者:HOS(安全风信子)
版本:1.0.0
"""
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import threading
import asyncio
class MemoryLevel(Enum):
    """The tier a memory item is stored in."""

    # Working memory.
    SHORT_TERM = "short_term"
    # Long-term memory.
    LONG_TERM = "long_term"
    # Permanent memory.
    PERMANENT = "permanent"
@dataclass
class MemoryItem:
    """A single stored memory together with its usage statistics."""

    memory_id: str
    content: str
    level: MemoryLevel
    # Both timestamps default to the moment the item is created.
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)
    access_count: int = 0
    importance_score: float = 0.5
    metadata: Dict[str, Any] = field(default_factory=dict)

    def access(self) -> None:
        """Record one access: refresh the timestamp and bump the counter."""
        self.last_accessed = datetime.now()
        self.access_count += 1
class MemorySystem:
    """Facade over the three memory tiers.

    Wires together short-term, long-term and permanent memory, the
    hybrid search engine, the sensitive-data detector, the encrypted
    store and the session compressor.
    """

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None
    ):
        """Build all components.

        Args:
            config: Optional settings.  The keys read here are
                ``"vector_store"`` and ``"knowledge_graph"``, which are
                passed to the long-term and permanent tiers.
        """
        self.config = config or {}

        # Memory tiers.
        self.short_term = ShortTermMemory(self)
        self.long_term = LongTermMemory(self, self.config.get("vector_store"))
        self.permanent = PermanentMemory(self, self.config.get("knowledge_graph"))

        # Retrieval across the long-term and permanent indexes.
        self.search_engine = HybridSearchEngine(
            vector_store=self.long_term.vector_store,
            bm25_indexer=self.long_term.bm25_indexer,
            knowledge_graph=self.permanent.knowledge_graph
        )

        # Privacy protection.
        self.sensitive_detector = SensitiveDetector()
        self.encryption = EncryptedMemoryStorage(
            key_manager=EncryptionKeyManager()
        )

        # Session compression.
        self.compressor = SessionCompressor(
            importance_scorer=ImportanceScorer(),
            summarizer=MemorySummarizer()
        )

        # Background-task state.
        self._decay_thread = None
        self._stop_decay = threading.Event()

    def store(
        self,
        content: str,
        level: MemoryLevel = MemoryLevel.SHORT_TERM,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """Redact ``content`` and store it in the requested tier.

        Args:
            content: Text to remember.
            level: Target tier.
            metadata: Optional metadata.  The caller's dict is never
                modified; when sensitive data is found a copy with an
                added ``"sensitive_matches"`` entry is stored instead.

        Returns:
            The id assigned by the tier.
        """
        redacted_content, matches = self.sensitive_detector.detect_and_redact(content)
        if matches:
            # Copy before annotating so the caller's dict stays intact.
            metadata = dict(metadata or {})
            metadata["sensitive_matches"] = [
                {"type": m.match_type.value, "confidence": m.confidence}
                for m in matches
            ]

        if level == MemoryLevel.SHORT_TERM:
            return self.short_term.store(redacted_content, metadata)
        elif level == MemoryLevel.LONG_TERM:
            return self.long_term.store(redacted_content, metadata)
        else:
            return self.permanent.store(redacted_content, metadata)

    def retrieve(
        self,
        query: str,
        query_vector: Optional[List[float]] = None,
        level: Optional[MemoryLevel] = None,
        top_k: int = 10
    ) -> List[MemoryItem]:
        """Look up memories.

        Args:
            query: Query text.
            query_vector: Optional query embedding.
            level: Restrict the search to one tier; ``None`` runs the
                hybrid search engine instead.
            top_k: Maximum number of results.

        Returns:
            Matching memory items.
        """
        if level is not None:
            if level == MemoryLevel.SHORT_TERM:
                return self.short_term.search(query, top_k)
            elif level == MemoryLevel.LONG_TERM:
                return self.long_term.search(query, query_vector, top_k)
            else:
                return self.permanent.search(query, top_k)

        # Hybrid search; a zero vector stands in for a missing embedding.
        results = self.search_engine.search(
            query=query,
            query_vector=query_vector or [0.0] * 1536,
            top_k=top_k
        )
        return [self._search_result_to_memory_item(r) for r in results]

    def _search_result_to_memory_item(self, result: Any) -> MemoryItem:
        """Convert one hybrid-search result into a ``MemoryItem``.

        Reads ``chunk_id``, ``content`` and ``metadata`` from ``result``,
        the same attributes the per-tier converters read.
        """
        # NOTE(review): the results visible here carry no tier
        # information, so the item is labelled LONG_TERM; confirm
        # against the real HybridSearchEngine result type.
        return MemoryItem(
            memory_id=result.chunk_id,
            content=result.content,
            level=MemoryLevel.LONG_TERM,
            metadata=result.metadata
        )

    def compress_session(self, session_id: str) -> Dict[str, Any]:
        """Compress a session and persist what the compressor selected.

        Returns:
            The compressor's result dict.
        """
        turns = self.short_term.get_conversation_turns(session_id)
        context = self.short_term.get_context_summary(session_id)
        result = self.compressor.compress_session(turns, context)

        # Promote the selected memories to the durable tiers.
        persistent = result["persistent_memories"]
        for mem in persistent["permanent"] or ():
            self.permanent.store(mem["content"], {"session_id": session_id})
        for mem in persistent["longterm"] or ():
            self.long_term.store(mem["content"], {"session_id": session_id})
        return result

    def start(self) -> None:
        """Start the memory system and its long-term decay process."""
        self.long_term.start_decay_process()
        print("Memory System started")

    def stop(self) -> None:
        """Signal background work to stop and stop the decay process."""
        self._stop_decay.set()
        self.long_term.stop_decay_process()
        print("Memory System stopped")
class ShortTermMemory:
    """In-process short-term memory, grouped by session."""

    def __init__(self, parent: MemorySystem):
        self.parent = parent
        self._storage: Dict[str, MemoryItem] = {}
        self._sessions: Dict[str, List[str]] = {}  # session_id -> memory_ids

    def store(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Store ``content`` and register it with its session.

        Args:
            content: Text to remember.
            metadata: Optional metadata; ``"session_id"`` selects the
                session.  A missing or empty session id falls back to
                ``"default"``.  The dict is copied, not aliased.

        Returns:
            The new memory id.
        """
        import uuid

        # A timestamp alone can repeat for stores in quick succession
        # and would silently overwrite the earlier item.
        memory_id = f"st_{datetime.now().timestamp()}_{uuid.uuid4().hex[:8]}"
        meta = dict(metadata) if metadata else {}
        self._storage[memory_id] = MemoryItem(
            memory_id=memory_id,
            content=content,
            level=MemoryLevel.SHORT_TERM,
            metadata=meta
        )
        session_id = meta.get("session_id") or "default"
        self._sessions.setdefault(session_id, []).append(memory_id)
        return memory_id

    def search(self, query: str, top_k: int = 10) -> List[MemoryItem]:
        """Case-insensitive substring search, most recently used first.

        Items are ranked by their ``last_accessed`` value from before
        this call; only the returned items are then marked as accessed.
        """
        needle = query.lower()
        hits = [
            item for item in self._storage.values()
            if needle in item.content.lower()
        ]
        hits.sort(key=lambda item: item.last_accessed, reverse=True)
        selected = hits[:top_k]
        for item in selected:
            item.access()
        return selected

    def get_conversation_turns(self, session_id: str) -> List[Dict]:
        """Simplified stub: always returns an empty list."""
        return []

    def get_context_summary(self, session_id: str) -> Dict:
        """Simplified stub: always returns an empty dict."""
        return {}
class LongTermMemory:
    """Long-term memory backed by a vector store and a BM25 index."""

    def __init__(self, parent: MemorySystem, vector_store: Optional[Any] = None):
        self.parent = parent
        self.vector_store = vector_store or InMemoryVectorStore()
        self.bm25_indexer = BM25Indexer()
        self._storage: Dict[str, MemoryItem] = {}

    def store(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Store ``content`` and return its new id.

        The item is kept in the local table only.
        """
        # NOTE(review): nothing here adds the item to ``vector_store``
        # or ``bm25_indexer`` - confirm where indexing is meant to
        # happen.
        import uuid

        # A timestamp alone can repeat for stores in quick succession
        # and would silently overwrite the earlier item.
        memory_id = f"lt_{datetime.now().timestamp()}_{uuid.uuid4().hex[:8]}"
        self._storage[memory_id] = MemoryItem(
            memory_id=memory_id,
            content=content,
            level=MemoryLevel.LONG_TERM,
            metadata=dict(metadata) if metadata else {}
        )
        return memory_id

    def search(
        self,
        query: str,
        query_vector: Optional[List[float]],
        top_k: int = 10
    ) -> List[MemoryItem]:
        """Search long-term memory.

        Args:
            query: Query text; used only when no embedding is given.
            query_vector: Query embedding.  ``None`` switches to a
                case-insensitive substring search over locally stored
                items instead of passing ``None`` to the vector store.
            top_k: Maximum number of results.
        """
        if query_vector is None:
            return self._keyword_search(query, top_k)
        results = self.vector_store.search(query_vector, top_k)
        return [self._search_result_to_memory_item(r) for r in results]

    def _keyword_search(self, query: str, top_k: int) -> List[MemoryItem]:
        """Substring search over local items, most recently used first."""
        needle = query.lower()
        hits = [
            item for item in self._storage.values()
            if needle in item.content.lower()
        ]
        hits.sort(key=lambda item: item.last_accessed, reverse=True)
        selected = hits[:top_k]
        for item in selected:
            item.access()
        return selected

    def start_decay_process(self) -> None:
        """Simplified stub: does nothing."""
        pass

    def stop_decay_process(self) -> None:
        """Simplified stub: does nothing."""
        pass

    def _search_result_to_memory_item(self, result: Any) -> MemoryItem:
        """Wrap a vector-store result (chunk_id/content/metadata)."""
        return MemoryItem(
            memory_id=result.chunk_id,
            content=result.content,
            level=MemoryLevel.LONG_TERM,
            metadata=result.metadata
        )
class PermanentMemory:
    """Permanent memory backed by a knowledge graph."""

    def __init__(self, parent: MemorySystem, knowledge_graph: Optional[Any] = None):
        self.parent = parent
        self.knowledge_graph = knowledge_graph or KnowledgeGraph()
        self._storage: Dict[str, MemoryItem] = {}

    def store(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        """Store ``content`` in the local table and return its new id.

        The metadata dict is copied, not aliased.
        """
        import uuid

        # A timestamp alone can repeat for stores in quick succession
        # and would silently overwrite the earlier item.
        memory_id = f"pm_{datetime.now().timestamp()}_{uuid.uuid4().hex[:8]}"
        self._storage[memory_id] = MemoryItem(
            memory_id=memory_id,
            content=content,
            level=MemoryLevel.PERMANENT,
            metadata=dict(metadata) if metadata else {}
        )
        return memory_id

    def search(self, query: str, top_k: int = 10) -> List[MemoryItem]:
        """Look up knowledge-graph entities by name.

        Returns:
            Up to ``top_k`` items whose content is the entity's
            docstring, or its name when the docstring is empty.
        """
        entities = self.knowledge_graph.get_entities_by_name(query)
        return [
            MemoryItem(
                memory_id=entity.entity_id,
                content=entity.docstring or entity.name,
                level=MemoryLevel.PERMANENT,
                metadata={"entity_type": entity.entity_type.value}
            )
            for entity in entities[:top_k]
        ]

    def _search_result_to_memory_item(self, result: Any) -> MemoryItem:
        """Wrap a search result (chunk_id/content/metadata)."""
        return MemoryItem(
            memory_id=result.chunk_id,
            content=result.content,
            level=MemoryLevel.PERMANENT,
            metadata=result.metadata
        )
# 辅助类(简化实现)
from abc import ABC, abstractmethod
class InMemoryVectorStore(ABC):
    """Simplified stand-in for a vector store; it keeps no data."""

    def __init__(self, dimension: int = 1536):
        """Accept the embedding dimension; the stub ignores it."""

    def insert(self, chunk):
        """Pretend to insert ``chunk``; always reports success."""
        return True

    def search(self, query_embedding, top_k=10, filters=None):
        """Always returns an empty result list."""
        return []

    def delete(self, chunk_id):
        """Pretend to delete ``chunk_id``; always reports success."""
        return True
class BM25Indexer:
    """Simplified stand-in for a BM25 index; it keeps no data."""

    def add_documents(self, documents, id_field="id", content_field="content"):
        """Accept documents and discard them."""
        return None

    def search(self, query, top_k=10):
        """Always returns an empty result list."""
        return []
class KnowledgeGraph(ABC):
    """Simplified stand-in for the knowledge graph; it has no entities."""

    def get_entities_by_name(self, name):
        """Always returns an empty list."""
        return []
class HybridSearchEngine:
    """Simplified stand-in for the hybrid search engine."""

    def __init__(self, vector_store, bm25_indexer, knowledge_graph):
        """Accept the three backends; the stub keeps none of them."""

    def search(self, query, query_vector, top_k=10, strategy=None, filters=None):
        """Always returns an empty result list."""
        return []
class ImportanceScorer:
    """Simplified stand-in that rates every memory 0.5 on all axes."""

    def __init__(self):
        """Nothing to set up."""

    def score(self, memory_content, context, conversation_history=None):
        """Return a score object whose every attribute is 0.5."""
        class SB:
            novelty = 0.5
            accuracy = 0.5
            relevance = 0.5
            completeness = 0.5
            actionability = 0.5
            total = 0.5

        return SB()
class MemorySummarizer:
    """Simplified stand-in that returns a fixed summary record."""

    def __init__(self):
        """Nothing to set up."""

    def summarize(self, content, conversation_turns=None):
        """Return a result object with constant placeholder values."""
        class CR:
            original_length = 0
            compressed_length = 0
            compression_ratio = 1.0
            summary = "[压缩摘要]"
            key_points = []
            tokens_saved = 0.5
            retention_rate = 0.5

        return CR()
class SessionCompressor:
    """Simplified stand-in that never compresses anything."""

    def __init__(self, importance_scorer, summarizer):
        """Accept the collaborators; the stub keeps neither."""

    def compress_session(self, conversation_turns, context):
        """Return the turns unchanged with nothing selected to persist."""
        nothing_to_persist = {"permanent": [], "longterm": []}
        return {
            "compressed_turns": conversation_turns,
            "was_compressed": False,
            "persistent_memories": nothing_to_persist,
            "summary": None,
        }
class SensitiveDetector:
    """Simplified stand-in that never finds sensitive data."""

    def detect_and_redact(self, text):
        """Return ``text`` unchanged together with an empty match list."""
        no_matches = []
        return text, no_matches
class EncryptionKeyManager:
    """Simplified stand-in that hands out a constant all-zero key."""

    def __init__(self):
        """Nothing to set up."""

    def get_data_key(self, data_id):
        """Return 32 zero bytes regardless of ``data_id``."""
        return b'\x00' * 32
class EncryptedMemoryStorage:
    """Simplified stand-in for the encrypted store."""

    def __init__(self, key_manager):
        """Accept the key manager; the stub does not keep it."""
# 使用示例
if __name__ == "__main__":
    # Create the memory system and start it before it is used.
    memory = MemorySystem()
    memory.start()
    try:
        # Store a memory.
        memory.store(
            content="用户正在开发一个用户认证模块",
            level=MemoryLevel.SHORT_TERM,
            metadata={"session_id": "session_001", "files": ["auth.py"]}
        )

        # Retrieve memories (no level given, so the hybrid search runs).
        results = memory.retrieve(
            query="用户认证模块",
            top_k=5
        )
        print(f"检索到 {len(results)} 条相关记忆")
    finally:
        # Always pair start() with stop(), even when the demo fails.
        memory.stop()
# 关键词: AI IDE, Memory System, 短期记忆, 长期记忆, 永久记忆, 向量检索, 知识图谱, BM25, 混合搜索, 记忆压缩, 隐私保护, AES-GCM加密, 访问控制