首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >AI如何理解语言?自注意力机制的技术原理与代码实践

AI如何理解语言?自注意力机制的技术原理与代码实践

作者头像
AI浩
发布2025-08-14 15:14:13
发布2025-08-14 15:14:13
14000
代码可运行
举报
文章被收录于专栏:AI智韵AI智韵
运行总次数:0
代码可运行
自注意力机制让AI理解语言关系
自注意力机制让AI理解语言关系

自注意力机制让AI理解语言关系

在"Life is short, eat dessert first"这句话中,AI如何理解"dessert"与"eat"的关系,又为何知道"first"修饰的是整个行为而非某个具体单词?秘密就藏在自注意力机制中。

一、大模型的"大脑秘密"

当你与ChatGPT、文心一言或通义千问对话时,是否曾好奇过:AI是如何理解你的话语并给出连贯回复的?为什么它能抓住句子中词语间的微妙关系,甚至能领会言外之意?

这背后的核心技术,正是自注意力机制(Self-Attention)——Transformer架构的"大脑引擎",也是GPT-4、Llama等大语言模型(LLMs)得以理解和生成人类语言的关键所在。

2017年,Google在《Attention Is All You Need》论文中首次提出这一革命性技术,彻底改变了自然语言处理领域。如今,它已成为AI大模型的标配,却鲜少被大众了解。今天,让我们揭开这一技术的神秘面纱。

二、从"逐字翻译"到"理解上下文":自注意力的前世今生

想象一下,如果AI像早期的机器翻译系统那样逐字翻译,会发生什么?

"I love you because you are beautiful" 直译为 "我 爱 你 因为 你 是 美丽的"

这种翻译忽略了语言的复杂结构和上下文关系,结果往往生硬且不准确。

早期的解决方案是使用**循环神经网络(RNN)**,但RNN处理长文本时存在"记忆衰退"问题——就像我们很难记住长篇演讲中的所有细节。

自注意力机制的突破在于:让模型同时看到整个句子,并动态判断哪些词对当前理解最重要。就像你阅读时,大脑会自动关注关键信息,忽略次要内容。

注意力机制对比
注意力机制对比

注意力机制对比

上图:逐字翻译错误示例(上)与正确翻译(下)的对比

三、自注意力:AI的"关系网络分析仪"

让我们用一个简单例子理解自注意力的工作原理:

**"Life is short, eat dessert first"**(生命短暂,先吃甜点)

当模型处理"dessert"(甜点)这个词时,自注意力机制会计算它与句子中所有其他词的相关性:

  • "dessert"与"eat"(吃)有强关联
  • "dessert"与"first"(首先)也有一定联系
  • 但与"Life"(生命)的关联则较弱
单词关系可视化
单词关系可视化

单词关系可视化

可视化显示单词"making"如何通过注意力权重依赖或关注输入中的其他单词

技术上,这通过三个关键步骤实现:

  1. 查询-键-值转换:将每个词转化为"查询"(Query)、"键"(Key)和"值"(Value)向量
  2. 计算注意力权重:确定每个词对其他词的重要性
  3. 加权求和:根据权重组合信息,生成包含上下文的新表示

这一过程让模型能够"看到"词语间的复杂关系,就像人类理解语言一样。

四、动手实践:用代码实现自注意力

让我们通过代码深入理解自注意力机制。首先,我们需要准备输入数据:

代码语言:javascript
代码运行次数:0
运行
复制
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}

接下来,我们将句子转换为嵌入向量:

代码语言:javascript
代码运行次数:0
运行
复制
import torch
import torch.nn as nn
# 定义输入句子和创建词汇表字典
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)
# 输出: {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
# 将句子转换为整数索引表示
words = sentence.replace(',', '').split()
sentence_int = [dc[word] for word in words]
sentence_int = torch.tensor(sentence_int)
# 现在可以使用嵌入层
vocab_size = 50000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
torch.Size([6, 3])

现在,定义权重矩阵:

代码语言:javascript
代码运行次数:0
运行
复制
torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 2, 2, 4
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))

计算第二个输入元素的注意力向量:

代码语言:javascript
代码运行次数:0
运行
复制
x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
torch.Size([2])
torch.Size([2])
torch.Size([4])

计算所有输入的键和值:

代码语言:javascript
代码运行次数:0
运行
复制
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])

计算未归一化的注意力权重:

代码语言:javascript
代码运行次数:0
运行
复制
omega_2 = query_2 @ keys.T
print(omega_2)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374])

计算注意力权重:

代码语言:javascript
代码运行次数:0
运行
复制
import torch.nn.functional as F
attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229])

最后,计算上下文向量:

代码语言:javascript
代码运行次数:0
运行
复制
context_vector_2 = attention_weights_2 @ values
print(context_vector_2.shape)
print(context_vector_2)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
torch.Size([4])
tensor([0.5313, 1.3607, 0.7891, 1.3110])

将以上代码封装为自定义类:

代码语言:javascript
代码运行次数:0
运行
复制
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # 未归一化的注意力权重    
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

使用该类:

代码语言:javascript
代码运行次数:0
运行
复制
torch.manual_seed(123)
# 将d_out_v从4减少到1,因为我们有4个头
d_in, d_out_kq, d_out_v = 3, 2, 4
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[-0.1564,  0.1028, -0.0763, -0.0764],
        [ 0.5313,  1.3607,  0.7891,  1.3110],
        [-0.3542, -0.1234, -0.2627, -0.3706],
        [ 0.0071,  0.3345,  0.0969,  0.1998],
        [ 0.1008,  0.4780,  0.2021,  0.3674],
        [-0.5296, -0.2799, -0.4107, -0.6006]], grad_fn=<MmBackward0>)

五、多头注意力:AI的"多维视角"

单一的自注意力就像用一只眼睛看世界——虽然能看到关系,但视角有限。

多头注意力则如同给AI装上了"多双眼睛",从不同角度分析同一句话:

  • 一个"头"可能专注于语法结构
  • 另一个"头"可能关注情感色彩
  • 还有一个"头"可能识别实体关系

实现多头注意力:

代码语言:javascript
代码运行次数:0
运行
复制
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [SelfAttention(d_in, d_out_kq, d_out_v) 
             for _ in range(num_heads)]
        )
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

测试多头注意力:

代码语言:javascript
代码运行次数:0
运行
复制
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 1
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

torch.manual_seed(123)
block_size = embedded_sentence.shape[1]
mha = MultiHeadAttentionWrapper(
    d_in, d_out_kq, d_out_v, num_heads=4
)
context_vecs = mha(embedded_sentence)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[-0.0185],
        [ 0.4003],
        [-0.1103],
        [ 0.0668],
        [ 0.1180],
        [-0.1827]], grad_fn=<MmBackward0>)

tensor([[-0.0185,  0.0170,  0.1999, -0.0860],
        [ 0.4003,  1.7137,  1.3981,  1.0497],
        [-0.1103, -0.1609,  0.0079, -0.2416],
        [ 0.0668,  0.3534,  0.2322,  0.1008],
        [ 0.1180,  0.6949,  0.3157,  0.2807],
        [-0.1827, -0.2060, -0.2393, -0.3167]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([6, 4])

六、因果自注意力:大模型如何"一步步"思考

当你使用ChatGPT时,它不会一次性生成整段回复,而是逐字生成,就像我们说话一样。

这就需要因果自注意力(也称掩码自注意力)——确保模型在预测下一个词时,只能"看到"前面的词,而不能"偷看"未来的内容。

首先,计算注意力分数:

代码语言:javascript
代码运行次数:0
运行
复制
torch.manual_seed(123)
d_in, d_out_kq, d_out_v = 3, 2, 4
W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
W_key   = nn.Parameter(torch.rand(d_in, d_out_kq))
W_value = nn.Parameter(torch.rand(d_in, d_out_v))
x = embedded_sentence
keys = x @ W_key
queries = x @ W_query
values = x @ W_value
# attn_scores是"omegas",
# 未归一化的注意力权重
attn_scores = queries @ keys.T 
print(attn_scores)
print(attn_scores.shape)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[ 0.0613, -0.3491,  0.1443, -0.0437, -0.1303,  0.1076],
        [-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
        [ 0.2432, -1.3934,  0.5869, -0.1851, -0.5191,  0.4730],
        [-0.0794,  0.4487, -0.1807,  0.0518,  0.1677, -0.1197],
        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216, -0.2787],
        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],
       grad_fn=<MmBackward0>)
torch.Size([6, 6])

计算注意力权重:

代码语言:javascript
代码运行次数:0
运行
复制
attn_weights = torch.softmax(attn_scores / d_out_kq**0.5, dim=1)
print(attn_weights)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[0.1772, 0.1326, 0.1879, 0.1645, 0.1547, 0.1831],
        [0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
        [0.1965, 0.0618, 0.2506, 0.1452, 0.1146, 0.2312],
        [0.1505, 0.2187, 0.1401, 0.1651, 0.1793, 0.1463],
        [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.1231],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<SoftmaxBackward0>)

应用掩码:

代码语言:javascript
代码运行次数:0
运行
复制
block_size = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)

masked_simple = attn_weights*mask_simple
print(masked_simple)

row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

tensor([[0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0386, 0.6870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1965, 0.0618, 0.2506, 0.0000, 0.0000, 0.0000],
        [0.1505, 0.2187, 0.1401, 0.1651, 0.0000, 0.0000],
        [0.1347, 0.2758, 0.1162, 0.1621, 0.1881, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<MulBackward0>)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
        [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
        [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<DivBackward0>)

更高效的实现方法:

代码语言:javascript
代码运行次数:0
运行
复制
mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

attn_weights = torch.softmax(masked / d_out_kq**0.5, dim=1)
print(attn_weights)

输出:

代码语言:javascript
代码运行次数:0
运行
复制
tensor([[ 0.0613,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.6004,  3.4707,    -inf,    -inf,    -inf,    -inf],
        [ 0.2432, -1.3934,  0.5869,    -inf,    -inf,    -inf],
        [-0.0794,  0.4487, -0.1807,  0.0518,    -inf,    -inf],
        [-0.1510,  0.8626, -0.3597,  0.1112,  0.3216,    -inf],
        [ 0.4344, -2.5037,  1.0740, -0.3509, -0.9315,  0.9265]],
       grad_fn=<MaskedFillBackward0>)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0532, 0.9468, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3862, 0.1214, 0.4924, 0.0000, 0.0000, 0.0000],
        [0.2232, 0.3242, 0.2078, 0.2449, 0.0000, 0.0000],
        [0.1536, 0.3145, 0.1325, 0.1849, 0.2145, 0.0000],
        [0.1973, 0.0247, 0.3102, 0.1132, 0.0751, 0.2794]],
       grad_fn=<SoftmaxBackward0>)

七、自注意力的现实意义:不止于聊天机器人

自注意力机制的价值远超聊天机器人:

  • 精准搜索:理解查询与文档的深层语义关联
  • 智能写作:保持文章逻辑连贯、风格一致
  • 代码生成:理解编程语言的结构和上下文
  • 多语言翻译:捕捉语言间的复杂对应关系

更重要的是,它代表了一种通用的关系建模能力——不仅适用于语言,还可用于图像、音频甚至科学数据的处理。

八、局限与未来:自注意力的挑战

尽管强大,自注意力机制也有局限:

  • 计算成本高:处理长文本时计算量呈平方级增长
  • 缺乏真正的理解:模型只是学习统计模式,而非真正"理解"含义
  • 上下文长度限制:当前模型通常只能处理几千到几万token

为此,研究者们正在探索更高效的变体,如FlashAttention和线性注意力机制,以突破这些限制。

结语:理解AI,从理解其"大脑"开始

自注意力机制不仅是技术细节,更代表了AI理解世界的一种新范式——通过分析元素间的关系来获取意义

当我们与大模型对话时,背后是无数"注意力头"在默默工作,分析词语间的千丝万缕,试图捕捉人类语言的精髓。

下一次,当你惊叹于AI的"聪明"时,不妨想想这个精妙的机制——它让机器第一次真正"理解"了我们的语言,尽管这种理解仍与人类的意识相去甚远。

正如一位研究者所言:"自注意力不是魔法,但它是通往智能的重要一步。"

❝思考题:如果AI能通过自注意力理解语言关系,它是否也能理解人与人之间的情感联系?这种理解与人类的理解有何本质区别?欢迎在评论区分享你的见解!


参考资料:本文基于Sebastian Raschka的《Understanding and Coding Self-Attention》技术文章改编,保留了所有关键代码实现细节,适合对AI技术原理感兴趣的读者阅读。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-08-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI智韵 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、大模型的"大脑秘密"
  • 二、从"逐字翻译"到"理解上下文":自注意力的前世今生
  • 三、自注意力:AI的"关系网络分析仪"
  • 四、动手实践:用代码实现自注意力
  • 五、多头注意力:AI的"多维视角"
  • 六、因果自注意力:大模型如何"一步步"思考
  • 七、自注意力的现实意义:不止于聊天机器人
  • 八、局限与未来:自注意力的挑战
  • 结语:理解AI,从理解其"大脑"开始
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档