本文是对GRU的精简介绍,对于初学者可以看详细介绍:https://zh.d2l.ai/chapter_recurrent-modern/gru.html
GRU (Gate Recurrent Unit ) 背后的原理与 LSTM 非常相似,即用门控机制控制输入、记忆等信息而在当前时间步做出预测。
GRU 有两个门,即一个重置门(reset gate)和一个更新门(update gate)。从直观上来说,「重置门决定了如何将新的输入信息与前面的记忆相结合,更新门定义了前面记忆保存到当前时间步的量」。如果我们将重置门设置为 1,更新门设置为 0,那么我们将再次获得标准 RNN 模型。
GRU 原论文:https://arxiv.org/pdf/1406.1078v3.pdf
⊙ 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。 ⊕ 则代表进行矩阵加法操作。
这个隐状态包含了之前节点的相关信息。
根据输入获取重置的门控(reset gate)和 控制更新的门控(update gate)
为*sigmoid*函数,通过这个函数可以将数据变换为0-1范围内的数值,从而来充当门控信号。
如何根据门控重置数据
其中的
根据下面的公式获取:
⊙
class testGRU(nn.Module):
def __init__(self, input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, num_classes=num_classes, sequence_length=sequence_length):
super(SimpleGRU, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
self.fc1 = nn.Linear(hidden_size * sequence_length, num_classes)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
out,_ = self.gru(x, h0)
out = out.reshape(out.shape[0], -1)
out = self.fc1(out)
return out