import tensorflow as tf
from transformerx.layers.positional_encoding import SinePositionalEncoding
from transformerx.layers.transformer_encoder_block import TransformerEncoderBlock
class TransformerEncoder(tf.keras.layers.Layer):
    """Stack of Transformer encoder blocks on top of an embedding layer.

    Inputs are embedded, scaled by ``sqrt(depth)``, passed through a sine
    positional encoding and then through ``n_blocks`` encoder blocks in order.

    Args:
        vocab_size: Size of the embedding vocabulary.
        depth: Embedding width; also forwarded to the positional encoding and
            to every encoder block.
        norm_shape: Forwarded unchanged to every ``TransformerEncoderBlock``.
        ffn_num_hiddens: Forwarded unchanged to every encoder block.
        num_heads: Forwarded unchanged to every encoder block.
        n_blocks: Number of encoder blocks to stack.
        dropout: Forwarded to the positional encoding and to every block.
        bias: Forwarded unchanged to every encoder block.
    """

    def __init__(
        self,
        vocab_size,
        depth,
        norm_shape,
        ffn_num_hiddens,
        num_heads,
        n_blocks,
        dropout,
        bias=False,
    ):
        super().__init__()
        self.depth = depth
        self.n_blocks = n_blocks
        self.embedding = tf.keras.layers.Embedding(vocab_size, depth)
        self.pos_encoding = SinePositionalEncoding(depth, dropout)

        # Every block is built from the same positional arguments.
        block_args = (depth, norm_shape, ffn_num_hiddens, num_heads, dropout, bias)
        self.blocks = [
            TransformerEncoderBlock(*block_args) for _ in range(self.n_blocks)
        ]

    def call(self, X, valid_lens, **kwargs):
        """Encode ``X`` and record each block's attention weights.

        Args:
            X: Input fed to the embedding layer (presumably integer token
                ids -- confirm against callers).
            valid_lens: Forwarded unchanged to every encoder block.
            **kwargs: Forwarded to the positional encoding and to every block.

        Returns:
            The output of the last encoder block.

        Side effects:
            ``self.attention_weights`` is replaced by a list holding, for each
            block, ``blk.attention.attention.attention_weights`` as it stands
            right after that block ran.
        """
        embedded = self.embedding(X)
        # Scale the embeddings by sqrt(depth) before adding position information.
        scale = tf.math.sqrt(tf.cast(self.depth, dtype=tf.float32))
        hidden = self.pos_encoding(embedded * scale, **kwargs)

        self.attention_weights = [None for _ in self.blocks]
        for index, block in enumerate(self.blocks):
            hidden = block(hidden, valid_lens, **kwargs)
            self.attention_weights[index] = block.attention.attention.attention_weights
        return hidden
import tensorflow as tf
from transformerx.layers.positional_encoding import SinePositionalEncoding
from transformerx.layers.transformer_decoder_block import TransformerDecoderBlock
class TransformerDecoder(tf.keras.layers.Layer):
    """Stack of Transformer decoder blocks with an output projection.

    Inputs are embedded, scaled by ``sqrt(depth)``, passed through a sine
    positional encoding, run through ``n_blocks`` decoder blocks and finally
    projected to ``vocab_size`` units by a dense layer.

    Args:
        vocab_size: Size of the embedding vocabulary and of the output
            projection.
        depth: Embedding width; also forwarded to the positional encoding and
            to every decoder block.
        norm_shape: Forwarded unchanged to every ``TransformerDecoderBlock``.
        ffn_num_hiddens: Forwarded unchanged to every decoder block.
        num_heads: Forwarded unchanged to every decoder block.
        n_blocks: Number of decoder blocks to stack.
        dropout: Forwarded to the positional encoding and to every block.
    """

    def __init__(
        self,
        vocab_size,
        depth,
        norm_shape,
        ffn_num_hiddens,
        num_heads,
        n_blocks,
        dropout,
    ):
        super().__init__()
        self.depth = depth
        self.n_blocks = n_blocks
        self.embedding = tf.keras.layers.Embedding(vocab_size, depth)
        self.pos_encoding = SinePositionalEncoding(depth, dropout)
        self.blocks = [
            TransformerDecoderBlock(
                depth,
                norm_shape,
                ffn_num_hiddens,
                num_heads,
                dropout,
                i,  # each block receives its index in the stack
            )
            for i in range(n_blocks)
        ]
        self.dense = tf.keras.layers.Dense(vocab_size)
        # Initialise here so that the `attention_weights` property is usable
        # (returning None placeholders) before the first forward pass instead
        # of raising AttributeError.
        self._attention_weights = [[None] * n_blocks for _ in range(2)]

    def init_state(self, enc_outputs, enc_valid_lens):
        """Build the initial decoder state from the encoder results.

        Returns:
            A three-element list: ``enc_outputs``, ``enc_valid_lens`` and a
            list of ``n_blocks`` ``None`` entries (presumably one per-block
            cache slot filled in by the decoder blocks -- confirm against
            ``TransformerDecoderBlock``).
        """
        return [enc_outputs, enc_valid_lens, [None] * self.n_blocks]

    def call(self, X, state, **kwargs):
        """Decode ``X`` and record each block's attention weights.

        Args:
            X: Input fed to the embedding layer (presumably integer token
                ids -- confirm against callers).
            state: Decoder state; passed to and returned from every block.
            **kwargs: Forwarded to the positional encoding and to every block.

        Returns:
            A tuple ``(outputs, state)`` where ``outputs`` is the dense
            projection of the last block's output and ``state`` is the state
            returned by the last block.
        """
        X = self.pos_encoding(
            self.embedding(X) * tf.math.sqrt(tf.cast(self.depth, dtype=tf.float32)),
            **kwargs,
        )
        # Two attention layers per decoder block: index 0 collects the weights
        # of `attention1`, index 1 those of `attention2`.
        self._attention_weights = [[None] * len(self.blocks) for _ in range(2)]
        for i, blk in enumerate(self.blocks):
            X, state = blk(X, state, **kwargs)
            # Decoder self-attention weights
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        """Attention weights from the most recent call.

        A list of two lists, each with one entry per block; entries are
        ``None`` until the layer has been called.
        """
        return self._attention_weights
Transformer 将信息检索中的查询-键-值(QKV)概念与注意力机制相结合。
等式 1 中的矩阵 $A$ 通常称为注意力矩阵。作者选用点积注意力而非加性注意力(后者使用具有单个隐藏层的前馈网络来计算兼容性函数),原因在于点积注意力可以借助高度优化的矩阵乘法来实现,速度更快、空间效率更高。
尽管如此,当 $D_k$ 较大时,点积的数值会变得很大,从而把 softmax 函数推入梯度极小的区域。为了缓解 softmax 函数的梯度消失问题,将查询和键的点积除以 $\sqrt{D_k}$,因此这种注意力被称为缩放点积注意力。
import os
import numpy as np
import tensorflow as tf
from transformerx.utils import masked_softmax
class DotProductAttention(tf.keras.layers.Layer):
    """(Scaled) dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d)) V`` (the division is skipped when
    ``scaled`` is False), with optional causal masking, an optional
    ``attention_mask`` handled by ``masked_softmax`` and dropout applied to
    the attention weights.

    Args:
        dropout_rate: Dropout rate applied to the attention weights.
        scaled: If True, divide the scores by the square root of the last
            dimension of ``queries``.
        normalize: If True, divide the attention output by the square root of
            the last dimension of ``keys``.
        kernel_initializer: Stored on the instance; not used by the
            computation in this class.
        kernel_regularizer: Stored on the instance; not used by the
            computation in this class.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self,
        dropout_rate: float = 0,
        scaled: bool = True,
        normalize: bool = False,
        kernel_initializer: str = "ones",
        kernel_regularizer: str = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout_rate = dropout_rate
        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)
        self.scaled = scaled
        self.normalize = normalize
        # Populated by `call` with the weights of the most recent forward pass.
        self.attention_weights = None
        self.kernel_initializer = kernel_initializer
        self.kernel_regularizer = kernel_regularizer

    def build(self, input_shape):
        """No weights are created; defer to the base implementation."""
        super().build(input_shape)

    def call(
        self,
        queries: tf.Tensor,
        keys: tf.Tensor,
        values: tf.Tensor,
        attention_mask: tf.Tensor = None,
        causal_mask: bool = None,
        training=None,
        **kwargs
    ) -> tf.Tensor:
        """Apply attention of ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Tensor of shape ``(..., no. of queries, d)``.
            keys: Tensor of shape ``(..., no. of key-value pairs, d)``.
            values: Tensor of shape
                ``(..., no. of key-value pairs, value dimension)``.
            attention_mask: Forwarded unchanged to ``masked_softmax``
                (per the original notes: shape ``(batch_size,)`` or
                ``(batch_size, no. of queries)`` -- confirm against
                ``transformerx.utils.masked_softmax``).
            causal_mask: If truthy, query ``i`` may only attend to keys
                ``0..i``.
            training: Forwarded to the dropout layer; ``None`` lets Keras
                resolve the mode from the calling context.
            **kwargs: Forwarded to the dropout layer.

        Returns:
            Tensor of shape ``(..., no. of queries, value dimension)``.
        """
        scores = tf.matmul(queries, keys, transpose_b=True)
        if self.scaled:
            # Use the dynamic shape so an unknown static depth still works, and
            # the dtype of `scores` so the division never mixes float types.
            depth = tf.cast(tf.shape(queries)[-1], scores.dtype)
            scores = scores / tf.math.sqrt(depth)

        if causal_mask:
            num_queries = tf.shape(scores)[-2]
            num_keys = tf.shape(scores)[-1]
            # Lower-triangular (incl. diagonal) matrix of allowed positions;
            # it broadcasts over every leading (batch / head) dimension.
            visible = tf.linalg.band_part(
                tf.ones((num_queries, num_keys), dtype=tf.bool), -1, 0
            )
            # A large negative score becomes ~0 probability after softmax.
            scores = tf.where(visible, scores, tf.cast(-1e9, scores.dtype))

        self.attention_weights = masked_softmax(scores, attention_mask)
        output = tf.matmul(
            self.dropout(self.attention_weights, training=training, **kwargs),
            values,
        )
        if self.normalize:
            key_depth = tf.cast(tf.shape(keys)[-1], output.dtype)
            output /= tf.sqrt(key_depth)
        return output

    def get_attention_weights(self):
        """Return the attention weights stored by the most recent ``call``."""
        return self.attention_weights
import numpy as np
import tensorflow as tf
from einops import rearrange
from transformerx.layers.dot_product_attention import DotProductAttention
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are linearly projected to ``d_model`` units,
    split into ``num_heads`` heads, attended per head, concatenated again and
    passed through a final linear projection.

    Args:
        d_model: Width of the four projections; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate forwarded to the attention layer.
        bias: Whether the four dense projections use a bias term.
        attention: ``"scaled_dotproduct"`` (or ``None``) for scaled
            dot-product attention, ``"dotproduct"`` for the unscaled variant.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads`` or
            ``attention`` is not one of the supported values.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        dropout_rate: float = 0,
        bias: bool = False,
        attention: str = "scaled_dotproduct",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # `split_heads` factors the last axis into (num_heads, d_model / num_heads);
        # fail at construction rather than deep inside `rearrange` at call time.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.bias = bias
        if attention is None or attention == "scaled_dotproduct":
            self.attention = DotProductAttention(self.dropout_rate, scaled=True)
        elif attention == "dotproduct":
            self.attention = DotProductAttention(self.dropout_rate, scaled=False)
        else:
            # Previously this left `self.attention` undefined and failed later
            # with an AttributeError inside `call`.
            raise ValueError(
                f"Unsupported attention type {attention!r}; expected "
                "'scaled_dotproduct', 'dotproduct' or None"
            )
        self.W_q = tf.keras.layers.Dense(self.d_model, use_bias=self.bias)
        self.W_k = tf.keras.layers.Dense(self.d_model, use_bias=self.bias)
        self.W_v = tf.keras.layers.Dense(self.d_model, use_bias=self.bias)
        self.W_o = tf.keras.layers.Dense(self.d_model, use_bias=self.bias)

    def split_heads(self, X: tf.Tensor) -> tf.Tensor:
        """Reshape ``(batch, length, d_model)`` to ``(batch, heads, length, d_model / heads)``."""
        # Factor the feature axis into (heads, per-head depth) ...
        X = rearrange(X, "b l (h dk) -> b l h dk", h=self.num_heads)
        # ... then move the head axis in front of the sequence axis.
        X = rearrange(X, "b l h dk -> b h l dk")
        return X

    def inverse_transpose_qkv(self, X: tf.Tensor) -> tf.Tensor:
        """Undo ``split_heads``: ``(batch, heads, length, d)`` to ``(batch, length, heads * d)``."""
        # Transpose back to (batch_size, seq_len, num_heads, head_dim).
        X = rearrange(X, "b h l d -> b l h d")
        # Concatenate the num_heads dimension with the head_dim dimension.
        X = rearrange(X, "b l h d -> b l (h d)")
        return X

    def call(
        self,
        queries: tf.Tensor,
        values: tf.Tensor,
        keys: tf.Tensor,
        attention_mask: tf.Tensor = None,
        causal_mask: bool = False,
        **kwargs,
    ) -> tf.Tensor:
        """Run multi-head attention.

        Note the positional order ``(queries, values, keys)``: ``values``
        comes *before* ``keys``. It is kept for backward compatibility; prefer
        keyword arguments when the two tensors differ.

        Args:
            queries: Tensor of shape ``(batch, no. of queries, features)``.
            values: Tensor of shape ``(batch, no. of key-value pairs, features)``.
            keys: Tensor of shape ``(batch, no. of key-value pairs, features)``.
            attention_mask: Optional mask; repeated ``num_heads`` times along
                axis 0 and forwarded to the attention layer.
            causal_mask: Forwarded to the attention layer.
            **kwargs: Forwarded to the attention layer.

        Returns:
            Tensor of shape ``(batch, no. of queries, d_model)``.
        """
        # After projection and splitting: (batch, num_heads, length, d_model / num_heads).
        queries = self.split_heads(self.W_q(queries))
        keys = self.split_heads(self.W_k(keys))
        values = self.split_heads(self.W_v(values))

        if attention_mask is not None:
            # On axis 0, copy the first item (scalar or vector) num_heads
            # times, then copy the next item, and so on.
            # NOTE(review): this repetition matches a layout where heads are
            # folded into the batch axis, whereas `split_heads` keeps a
            # separate head axis -- confirm it is what `masked_softmax`
            # expects for 4-D scores.
            attention_mask = tf.repeat(attention_mask, repeats=self.num_heads, axis=0)

        # Shape of output: (batch, num_heads, no. of queries, d_model / num_heads).
        output = self.attention(
            queries, keys, values, attention_mask, causal_mask, **kwargs
        )

        # Shape of output_concat: (batch, no. of queries, d_model).
        output_concat = self.inverse_transpose_qkv(output)
        return self.W_o(output_concat)