自注意力机制让AI理解语言关系
❝在"Life is short, eat dessert first"这句话中,AI如何理解"dessert"与"eat"的关系,又为何知道"first"修饰的是整个行为而非某个具体单词?秘密就藏在自注意力机制中。
当你与ChatGPT、文心一言或通义千问对话时,是否曾好奇过:AI是如何理解你的话语并给出连贯回复的?为什么它能抓住句子中词语间的微妙关系,甚至能领会言外之意?
这背后的核心技术,正是自注意力机制(Self-Attention)——Transformer架构的"大脑引擎",也是GPT-4、Llama等大语言模型(LLMs)得以理解和生成人类语言的关键所在。
2017年,Google在《Attention Is All You Need》论文中首次提出这一革命性技术,彻底改变了自然语言处理领域。如今,它已成为AI大模型的标配,却鲜少被大众了解。今天,让我们揭开这一技术的神秘面纱。
想象一下,如果AI像早期的机器翻译系统那样逐字翻译,会发生什么?
❝"I love you because you are beautiful" 直译为 "我 爱 你 因为 你 是 美丽的"
这种翻译忽略了语言的复杂结构和上下文关系,结果往往生硬且不准确。
早期的解决方案是使用**循环神经网络(RNN)**,但RNN处理长文本时存在"记忆衰退"问题——就像我们很难记住长篇演讲中的所有细节。
自注意力机制的突破在于:让模型同时看到整个句子,并动态判断哪些词对当前理解最重要。就像你阅读时,大脑会自动关注关键信息,忽略次要内容。
注意力机制对比
上图:逐字翻译错误示例(上)与正确翻译(下)的对比
让我们用一个简单例子理解自注意力的工作原理:
**"Life is short, eat dessert first"**(生命短暂,先吃甜点)
当模型处理"dessert"(甜点)这个词时,自注意力机制会计算它与句子中所有其他词的相关性:
单词关系可视化
可视化显示单词"making"如何通过注意力权重依赖或关注输入中的其他单词
技术上,这通过三个关键步骤实现:
这一过程让模型能够"看到"词语间的复杂关系,就像人类理解语言一样。
让我们通过代码深入理解自注意力机制。首先,我们需要准备输入数据:
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)
输出:
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
接下来,我们将句子转换为嵌入向量:
import torch
import torch.nn as nn
# 定义输入句子和创建词汇表字典
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)
# 输出: {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
# 将句子转换为整数索引表示
words = sentence.replace(',', '').split()
sentence_int = [dc[word] for word in words]
sentence_int = torch.tensor(sentence_int)
# 现在可以使用嵌入层
vocab_size = 50000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)
输出:
tensor([[ 0.3374, -0.1778, -0.3035],
[ 0.1794, 1.8951, 0.4954],
[ 0.2692, -0.0770, -1.0205],
[-0.2196, -0.3792, 0.7671],
[-0.5880, 0.3486, 0.6603],
[-1.1925, 0.6984, -1.4097]])
torch.Size([6, 3])
现在,定义权重矩阵:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))
计算第二个输入元素的注意力向量:
x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)
输出:
torch.Size([2])
torch.Size([2])
torch.Size([4])
计算所有输入的键和值:
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
输出:
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])
计算未归一化的注意力权重:
omega_2 = query_2 @ keys.T
print(omega_2)
输出:
tensor([-0.6004, 3.4707, -1.5023, 0.4991, 1.2903, -1.3374])
计算注意力权重:
import torch.nn.functional as F
attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)
输出:
tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229])
最后,计算上下文向量:
context_vector_2 = attention_weights_2 @ values
print(context_vector_2.shape)
print(context_vector_2)
输出:
torch.Size([4])
tensor([0.5313, 1.3607, 0.7891, 1.3110])
将以上代码封装为自定义类:
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, d_in, d_out_kq, d_out_v):
super().__init__()
self.d_out_kq = d_out_kq
self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))
def forward(self, x):
keys = x @ self.W_key
queries = x @ self.W_query
values = x @ self.W_value
attn_scores = queries @ keys.T # 未归一化的注意力权重
attn_weights = torch.softmax(
attn_scores / self.d_out_kq**0.5, dim=-1
)
context_vec = attn_weights @ values
return context_vec
使用该类:
torch.manual_seed(123)
# 将d_out_v从4减少到1,因为我们有4个头
d_in, d_out_kq, d_out_v = 3, 2, 4
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))
输出:
tensor([[-0.1564, 0.1028, -0.0763, -0.0764],
[ 0.5313, 1.3607, 0.7891, 1.3110],
[-0.3542, -0.1234, -0.2627, -0.3706],
[ 0.0071, 0.3345, 0.0969, 0.1998],
[ 0.1008, 0.4780, 0.2021, 0.3674],
[-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)
单一的自注意力就像用一只眼睛看世界——虽然能看到关系,但视角有限。
多头注意力则如同给AI装上了"多双眼睛",从不同角度分析同一句话:
实现多头注意力:
class MultiHeadAttentionWrapper(nn.Module):
def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
super().__init__()
self.heads = nn.ModuleList(
[SelfAttention(d_in, d_out_kq, d_out_v)
for _ in range(num_heads)]
)
def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=-1)
测试多头注意力:
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 1
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))
torch.manual_seed(123)
block_size = embedded_sentence.shape[1]
mha = MultiHeadAttentionWrapper(
d_in, d_out_kq, d_out_v, num_heads=4
)
context_vecs = mha(embedded_sentence)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
输出:
tensor([[-0.0185],
[ 0.4003],
[-0.1103],
[ 0.0668],
[ 0.1180],
[-0.1827]], grad_fn=<MmBackward0>)
tensor([[-0.0185, 0.0170, 0.1999, -0.0860],
[ 0.4003, 1.7137, 1.3981, 1.0497],
[-0.1103, -0.1609, 0.0079, -0.2416],
[ 0.0668, 0.3534, 0.2322, 0.1008],
[ 0.1180, 0.6949, 0.3157, 0.2807],
[-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([6, 4])
当你使用ChatGPT时,它不会一次性生成整段回复,而是逐字生成,就像我们说话一样。
这就需要因果自注意力(也称掩码自注意力)——确保模型在预测下一个词时,只能"看到"前面的词,而不能"偷看"未来的内容。
首先,计算注意力分数:
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
W_value = nn.Parameter(torch.rand(d_in, d_out_v))
x = embedded_sentence
keys = x @ W_key
queries = x @ W_query
values = x @ W_value
# attn_scores是"omegas",
# 未归一化的注意力权重
attn_scores = queries @ keys.T
print(attn_scores)
print(attn_scores.shape)
输出:
tensor([[ 0.0613, -0.3491, 0.1443, -0.0437, -0.1303, 0.1076],
[-0.6004, 3.4707, -1.5023, 0.4991, 1.2903, -1.3374],
[ 0.2432, -1.3934, 0.5869, -0.1851, -0.5191, 0.4730],
[-0.0794, 0.4487, -0.1807, 0.0518, 0.1677, -0.1197],
[-0.1510, 0.8626, -0.3597, 0.1112, 0.3216, -0.2787],
[ 0.4344, -2.5037, 1.0740, -0.3509, -0.9315, 0.9265]],
grad_fn=<MmBackward0>)
torch.Size([6, 6])
计算注意力权重:
attn_weights = torch.softmax(attn_scores / d_out_kq**0.5, dim=1)
print(attn_weights)
输出:
tensor([[0.1772, 0.1326, 0.1879, 0.1645, 0.1547, 0.1831],
[0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
[0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],
[0.1505, 0.2187, 0.1401, 0.1651, 0.1793, 0.1463],
[0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.1231],
[0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
grad_fn=<SoftmaxBackward0>)
应用掩码:
block_size = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)
masked_simple = attn_weights*mask_simple
print(masked_simple)
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)
输出:
tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1.]])
tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0386, 0.6870, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1965, 0.0618, 0.2506, 0.0000, 0.0000, 0.0000],
[0.1505, 0.2187, 0.1401, 0.1651, 0.0000, 0.0000],
[0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.0000],
[0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
grad_fn=<MulBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
[0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
[0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
[0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
grad_fn=<DivBackward0>)
更高效的实现方法:
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)
attn_weights = torch.softmax(masked / d_out_kq**0.5, dim=1)
print(attn_weights)
输出:
tensor([[ 0.0613, -inf, -inf, -inf, -inf, -inf],
[-0.6004, 3.4707, -inf, -inf, -inf, -inf],
[ 0.2432, -1.3934, 0.5869, -inf, -inf, -inf],
[-0.0794, 0.4487, -0.1807, 0.0518, -inf, -inf],
[-0.1510, 0.8626, -0.3597, 0.1112, 0.3216, -inf],
[ 0.4344, -2.5037, 1.0740, -0.3509, -0.9315, 0.9265]],
grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
[0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
[0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
[0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
grad_fn=<SoftmaxBackward0>)
自注意力机制的价值远超聊天机器人:
更重要的是,它代表了一种通用的关系建模能力——不仅适用于语言,还可用于图像、音频甚至科学数据的处理。
尽管强大,自注意力机制也有局限:
为此,研究者们正在探索更高效的变体,如FlashAttention和线性注意力机制,以突破这些限制。
自注意力机制不仅是技术细节,更代表了AI理解世界的一种新范式——通过分析元素间的关系来获取意义。
当我们与大模型对话时,背后是无数"注意力头"在默默工作,分析词语间的千丝万缕,试图捕捉人类语言的精髓。
下一次,当你惊叹于AI的"聪明"时,不妨想想这个精妙的机制——它让机器第一次真正"理解"了我们的语言,尽管这种理解仍与人类的意识相去甚远。
正如一位研究者所言:"自注意力不是魔法,但它是通往智能的重要一步。"
❝思考题:如果AI能通过自注意力理解语言关系,它是否也能理解人与人之间的情感联系?这种理解与人类的理解有何本质区别?欢迎在评论区分享你的见解!
参考资料:本文基于Sebastian Raschka的《Understanding and Coding Self-Attention》技术文章改编,保留了所有关键代码实现细节,适合对AI技术原理感兴趣的读者阅读。