self-attention机制的实现步骤
表示有3个token(可以是单词、句子)
self-attention 机制会在这 3 个 token 之间计算注意力分数,从而让每个 token 能够关注到其他 token 的信息。
import torch
x = [
[1, 0, 1, 0], # Input 1
[0, 2, 0, 2], # Input 2
[1, 1, 1, 1] # Input 3
]
x = torch.tensor(x)
w_K、w_Q、w_V后期随着训练而更新
x = torch.tensor(x, dtype=torch.float32)
# 初始化参数
# 每一个输入都有三个表示,分别为key、query、value
# 初始化权重向量
w_K = [
[0, 0, 1],
[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
]
w_Q = [
[1, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 1]
]
w_V = [
[0, 2, 0],
[0, 3, 0],
[1, 0, 3],
[1, 1, 0]
]
w_key = torch.tensor(w_K, dtype=torch.float32)
w_query = torch.tensor(w_Q, dtype=torch.float32)
w_value = torch.tensor(w_V, dtype=torch.float32)
[0, 0, 1]
[1, 0, 1, 0] [1, 1, 0] [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1] [1, 1, 0] [2, 3, 1]
输入词向量*w_K权重矩阵,得到了keys
# 将query key value分别进行计算
keys = x @ w_key
querys = x @ w_query
values = x @ w_value
print("Keys: \n", keys)
print("Querys: \n", querys)
print("Values: \n", values)
attn_scores = querys @ keys.T
print(attn_scores)
from torch.nn.functional import softmax
attn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
attn_scores_softmax = [
[0.0, 0.5, 0.5],
[0.0, 1.0, 0.0],
[0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
print(attn_scores_softmax)
这里由于归一化值比较乱,就大概赋值为新的attn_scores_softmax了。
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)
# 给value加权求和
# output就是自注意力的输出
output1 = weighted_values.sum(dim=0)
print('output1:', output1)
output2 = weighted_values.sum(dim=1)
print('output2:', output2)
output3 = weighted_values.sum(dim=2)
print('output3:', output3)
解释输出
在 self-attention 机制中,weighted_values 是经过注意力权重加权后的 value 向量。
attn_scores_softmax:通过 softmax 函数归一化后的注意力分数矩阵,表示每个 token 对其他 token 的关注度。
比如在机器翻译中,weighted_values 帮助模型在翻译过程中关注源语言句子中的不同部分;例如,在翻译 "The cat is on the mat" 时,模型可以通过weighted_values 更好地理解 "cat" 和 "mat" 之间的关系,并生成更准确的目标语言句子。
在实际任务中,output 通常是 output1,即每个 token 经过自注意力机制后的新表示,这个新表示综合了该 token 对其他所有 token 的关注度及其对应的 value 信息。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。