Multi-layer Graph Convolutional Network (GCN) with first-order filters. Source: http://tkipf.github.io/graph-convolutional-networks/
The complete code and data for this article have been uploaded to GitHub; feedback is very welcome, thanks! https://github.com/YoungTimes/GNN/tree/master/GCN
Graph Convolutional Networks (GCN), like CNNs, are tools for feature extraction. CNNs are very powerful on regular data structures such as images.
Image matrix illustration (Euclidean structure). Image source: [4]
In the real world, however, many data structures are irregular. The typical example is the graph, as in social networks and knowledge graphs, and this is exactly the kind of data GNNs are good at handling.
Social network topology illustration (non-Euclidean structure). Image source: [4]
This article walks through a complete example of classifying papers with a GCN to show how GCN works and why.
We use the Cora dataset, which consists of the features and class labels of 2708 papers, plus the 5429 citation edges between them. The papers are divided into 7 classes: Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning, Theory.
The final goal: given a paper's features as input, output the class that paper belongs to.
https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
The Cora dataset is a dataset for classifying machine learning papers. It contains three files:
-- README: an introduction to the dataset;
-- cora.cites: the citation graph between papers. Each line contains two paper IDs: the first is the ID of the cited paper, the second the ID of the citing paper. For example, the line "35 1033" means paper 1033 cites paper 35.
-- cora.content: information on the 2708 papers, one per line, in the format <paper_id> <word_attributes>+ <class_label>. paper_id uniquely identifies the paper; word_attributes is a 1433-dimensional binary word vector where each element corresponds to one vocabulary word (0 means the word does not appear in the paper, 1 means it does); class_label maps each paper to one of the 7 classes listed above: Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning, Theory.
Let's first take a look at what the data in the Cora dataset looks like...
import pandas as pd
import numpy as np
import scipy.sparse as sp

# Load the data; fields are tab-separated
raw_data_content = pd.read_csv('data/cora/cora.content', sep='\t', header=None)

# [2708 * 1435]
(row, col) = raw_data_content.shape
print("Cora Content's Row: {}, Col: {}".format(row, col))
print("=============================================")

# Each row is a 1435-dim vector: the first element is the paper ID,
# the last element is the paper's label
raw_data_sample = raw_data_content.head(3)
features_sample = raw_data_sample.iloc[:, 1:-1]
labels_sample = raw_data_sample.iloc[:, -1]
labels_onehot_sample = pd.get_dummies(labels_sample)

print("features:{}".format(features_sample))
print("=============================================")
print("labels:{}".format(labels_sample))
print("=============================================")
print("labels one hot:{}".format(labels_onehot_sample))
Cora Content's Row: 2708, Col: 1435
=============================================
features: 1 2 3 4 5 6 7 8 9 10 ... 1424 \
0 0 0 0 0 0 0 0 0 0 0 ... 0
1 0 0 0 0 0 0 0 0 0 0 ... 0
2 0 0 0 0 0 0 0 0 0 0 ... 0
1425 1426 1427 1428 1429 1430 1431 1432 1433
0 0 0 1 0 0 0 0 0 0
1 0 1 0 0 0 0 0 0 0
2 0 0 0 0 0 0 0 0 0
[3 rows x 1433 columns]
=============================================
labels:0 Neural_Networks
1 Rule_Learning
2 Reinforcement_Learning
Name: 1434, dtype: object
=============================================
labels one hot: Neural_Networks Reinforcement_Learning Rule_Learning
0 1 0 0
1 0 0 1
2 0 1 0
raw_data_cites = pd.read_csv('data/cora/cora.cites', sep='\t', header=None)

# [5429 * 2]
(row, col) = raw_data_cites.shape
print("Cora Cites' Row: {}, Col: {}".format(row, col))
print("=============================================")

raw_data_cites_sample = raw_data_cites.head(10)
print(raw_data_cites_sample)
print("=============================================")

# Convert the citation list into an adjacency matrix.
# First map the (non-contiguous) paper IDs to indices 0..2707.
idx = np.array(raw_data_content.iloc[:, 0], dtype=np.int32)
idx_map = {j: i for i, j in enumerate(idx)}
edge_indexs = np.array(list(map(idx_map.get, raw_data_cites.values.flatten())), dtype=np.int32)
edge_indexs = edge_indexs.reshape(raw_data_cites.shape)

# Build a sparse N x N adjacency matrix; N is the number of nodes, not edges
adjacency = sp.coo_matrix((np.ones(len(edge_indexs)),
                           (edge_indexs[:, 0], edge_indexs[:, 1])),
                          shape=(idx.shape[0], idx.shape[0]), dtype="float32")
print(adjacency)
Cora Cites' Row: 5429, Col: 2
0 1
0 35 1033
1 35 103482
2 35 103515
3 35 1050679
4 35 1103960
... ... ...
5424 853116 19621
5425 853116 853155
5426 853118 1140289
5427 853155 853118
5428 954315 1155073
[5429 rows x 2 columns]
=============================================
(163, 402) 1.0
(163, 659) 1.0
(163, 1696) 1.0
(163, 2295) 1.0
(163, 1274) 1.0
: :
(1886, 745) 1.0
(1886, 1902) 1.0
(1887, 2258) 1.0
(1902, 1887) 1.0
(837, 1686) 1.0
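One detail worth noting: cora.cites records directed citation edges, so the adjacency matrix built above is asymmetric, while GCN operates on an undirected graph. A common preprocessing step (used, for example, in Kipf's reference implementation) is to symmetrize the matrix. A minimal sketch, assuming the adjacency variable from the block above:

# Symmetrize: keep edge (i, j) if either (i, j) or (j, i) appears in cora.cites
adjacency = adjacency + adjacency.T.multiply(adjacency.T > adjacency) \
            - adjacency.multiply(adjacency.T > adjacency)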
Here we use nodes [0, 150) as the training set, nodes [150, 500) as the validation set, and nodes [500, 2708) as the test set. In the implementation, boolean masks (train_mask, val_mask, test_mask) distinguish the three sets.
train_index = np.arange(150)
val_index = np.arange(150, 500)
test_index = np.arange(500, 2708)

# One boolean flag per node (2708 nodes), not per edge
num_nodes = raw_data_content.shape[0]
train_mask = np.zeros(num_nodes, dtype=bool)
val_mask = np.zeros(num_nodes, dtype=bool)
test_mask = np.zeros(num_nodes, dtype=bool)
train_mask[train_index] = True
val_mask[val_index] = True
test_mask[test_index] = True
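As a quick sanity check (my own addition, not in the original post), the three masks should select 150, 350, and 2208 nodes respectively:

# Verify the split sizes: 150 + 350 + 2208 = 2708
print(train_mask.sum(), val_mask.sum(), test_mask.sum())  # 150 350 2208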
A graph is one of the most important concepts in data structures, and yes, the graph in a graph neural network is exactly the graph from data structures. Suppose the network's input graph contains N nodes, each with d features. The features of all nodes form an N x d matrix X, and the adjacency relations between nodes form an N x N adjacency matrix A. Together, X and A constitute the input of the graph neural network.
Shown below is an undirected graph made up of 5 nodes, together with its corresponding adjacency matrix A.
On top of the adjacency matrix we add self-loops, i.e. add the identity matrix I:

$\tilde{A} = A + I$

For simplicity, assume each node of the graph has a one-dimensional feature, so the features form a column vector $X$. Multiplying the two gives

$(\tilde{A}X)_i = x_i + \sum_{j \in N(i)} x_j$

What do we see from this expression? Exactly: each node fuses the information of its first-order neighbors into its own representation, and that is the essence of formula (1). Propagation through the network is really just each node in the graph repeatedly aggregating information from its neighbor nodes. By multiplying twice in a row,

$\tilde{A}(\tilde{A}X) = \tilde{A}^2 X$,

each node also fuses the information of its second-order neighbors.

At this point we can understand why the identity matrix is added to the adjacency matrix: the diagonal entries of the adjacency matrix are all 0, so multiplying the feature matrix by the adjacency matrix alone would throw away every node's own information. Adding the identity matrix preserves each node's own features.
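To make this concrete, here is a tiny numpy sketch (a hypothetical 3-node path graph, not from the original post) showing that multiplying by A alone drops each node's own feature, while multiplying by A + I keeps it:

import numpy as np

# Hypothetical 3-node path graph: 0 - 1 - 2
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
X = np.array([[1.0], [2.0], [3.0]])  # one feature per node

print(A @ X)                 # [[2.], [4.], [2.]] -- each node's own feature is lost
print((A + np.eye(3)) @ X)   # [[3.], [6.], [5.]] -- self feature preserved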
The next point I will only explain intuitively. (To be honest, I haven't had time to work through the full derivation yet; I'll add it later, heh.)
Suppose node A's only neighbor is B. The most direct way to aggregate information for A is to average the contributions: New_A = 0.5 * A + 0.5 * B. This looks reasonable, but if B has a great many neighbors (in the extreme case, B is connected to every other node in the graph), then after aggregation the features of many nodes in the graph will become very similar. We therefore need to take node degree into account and prevent high-degree nodes from contributing too much.
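This is exactly what the symmetric normalization in GCN does: each entry of $\tilde{A}$ is divided by the square roots of the degrees of both endpoints, giving $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$, where $\tilde{D}$ is the degree matrix of $\tilde{A}$. A minimal scipy sketch (the helper name is my own, not from the original code):

import numpy as np
import scipy.sparse as sp

def normalize_adjacency(adjacency):
    """Compute D^-1/2 (A + I) D^-1/2 as used by GCN."""
    adjacency = adjacency + sp.eye(adjacency.shape[0])   # add self-loops
    degree = np.array(adjacency.sum(axis=1)).flatten()   # node degrees (>= 1 after self-loops)
    d_inv_sqrt = sp.diags(np.power(degree, -0.5))        # D^-1/2
    return d_inv_sqrt.dot(adjacency).dot(d_inv_sqrt).tocoo()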
The graph convolution layer underlying our two-layer GCN network is implemented as follows:
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras import backend as K


class GraphConvolution(tf.keras.layers.Layer):
    """Basic graph convolution layer as in https://arxiv.org/abs/1609.02907"""
    def __init__(self, units, support=1,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 ):
        super(GraphConvolution, self).__init__()
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.supports_masking = True

        self.support = support
        assert support >= 1

    def build(self, input_shapes):
        features_shape = input_shapes[0]
        assert len(features_shape) == 2
        input_dim = features_shape[1]

        self.kernel = self.add_weight(shape=(input_dim * self.support, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, mask=None):
        features = inputs[0]   # node feature matrix X
        basis = inputs[1:]     # list of (normalized) adjacency matrices

        # Aggregate neighbor features: one A_hat * X per support matrix
        supports = list()
        for i in range(self.support):
            supports.append(K.dot(basis[i], features))
        supports = K.concatenate(supports, axis=1)

        # Linear transform: (A_hat * X) * W, plus optional bias
        output = K.dot(supports, self.kernel)
        if self.bias is not None:
            output += self.bias
        return self.activation(output)
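The two-layer model itself is imported below from the repo's graph module as GraphConvolutionModel and is not reproduced in the post. Based on the layer above and the standard two-layer GCN architecture, a minimal sketch might look like this (the 16 hidden units and L2 weight regularization are my assumptions; see the repo for the actual implementation):

class GraphConvolutionModel(tf.keras.Model):
    """Two-layer GCN: A_hat X -> 16 hidden units (ReLU) -> 7 class logits."""
    def __init__(self, hidden_units=16, num_classes=7):
        super(GraphConvolutionModel, self).__init__()
        self.gcn1 = GraphConvolution(hidden_units, activation='relu',
                                     kernel_regularizer=tf.keras.regularizers.l2(5e-4))
        self.gcn2 = GraphConvolution(num_classes)  # raw logits; softmax lives in the loss

    def call(self, inputs, training=None):
        features, adj = inputs
        h = self.gcn1([features, adj])
        return self.gcn2([h, adj])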
The data-processing details have mostly been covered above. One point worth noting: during preprocessing, each node's feature vector also needs to be normalized.
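A minimal sketch of the usual row normalization, dividing each feature vector by its sum (a helper of my own, assuming features is a numpy array; the CoraData class used below is expected to perform this step internally):

def normalize_features(features):
    """Row-normalize the binary word vectors so each row sums to 1."""
    row_sum = features.sum(axis=1, keepdims=True)
    row_sum[row_sum == 0] = 1.0          # guard against all-zero rows
    return features / row_sum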
from graph import GraphConvolutionLayer, GraphConvolutionModel
from dataset import CoraData
import time
import tensorflow as tf
import matplotlib.pyplot as plt
dataset = CoraData()
features, labels, adj, train_mask, val_mask, test_mask = dataset.data()
graph = [features, adj]
Process data ...
Loading cora dataset...
Dataset has 2708 nodes, 2708 edges, 1433 features.
In the loss function, the loss is computed only over the training nodes (those where train_mask is True).
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

def loss(model, x, y, train_mask, training):
    y_ = model(x, training=training)
    # Keep only the logits/labels of the training nodes
    masked_logits = tf.gather_nd(y_, tf.where(train_mask))
    masked_labels = tf.gather_nd(y, tf.where(train_mask))
    return loss_object(y_true=masked_labels, y_pred=masked_logits)
def grad(model, inputs, targets, train_mask):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets, train_mask, training=True)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)
def test(mask):
    logits = model(graph)
    masked_logits = tf.gather_nd(logits, tf.where(mask))
    masked_labels = tf.gather_nd(labels, tf.where(mask))
    correct = tf.math.equal(tf.math.argmax(masked_labels, -1), tf.math.argmax(masked_logits, -1))
    accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float64))
    return accuracy
model = GraphConvolutionModel()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, decay=5e-5)

# Record metrics during training for later visualization
train_loss_results = []
train_accuracy_results = []
train_val_results = []
train_test_results = []

num_epochs = 200
for epoch in range(num_epochs):
    loss_value, grads = grad(model, graph, labels, train_mask)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    accuracy = test(train_mask)
    val_acc = test(val_mask)
    test_acc = test(test_mask)

    train_loss_results.append(loss_value)
    train_accuracy_results.append(accuracy)
    train_val_results.append(val_acc)
    train_test_results.append(test_acc)

    print("Epoch {} loss={} accuracy={} val_acc={} test_acc={}".format(epoch, loss_value, accuracy, val_acc, test_acc))
Epoch 0 loss=1.9472886323928833 accuracy=0.4066666666666667 val_acc=0.3142857142857143 test_acc=0.2817028985507246
Epoch 1 loss=1.9314587116241455 accuracy=0.4866666666666667 val_acc=0.3742857142857143 test_acc=0.33106884057971014
Epoch 2 loss=1.9133251905441284 accuracy=0.5 val_acc=0.38571428571428573 test_acc=0.34782608695652173
Epoch 3 loss=1.8908278942108154 accuracy=0.5266666666666666 val_acc=0.3942857142857143 test_acc=0.3496376811594203
Epoch 4 loss=1.8662141561508179 accuracy=0.5533333333333333 val_acc=0.3942857142857143 test_acc=0.3423913043478261
Epoch 5 loss=1.8400791883468628 accuracy=0.56 val_acc=0.38285714285714284 test_acc=0.3401268115942029
Epoch 6 loss=1.8119205236434937 accuracy=0.5866666666666667 val_acc=0.38571428571428573 test_acc=0.33605072463768115
Epoch 7 loss=1.78205144405365 accuracy=0.6066666666666667 val_acc=0.37714285714285717 test_acc=0.33016304347826086
Epoch 8 loss=1.751450777053833 accuracy=0.6066666666666667 val_acc=0.38857142857142857 test_acc=0.33016304347826086
Epoch 9 loss=1.7200360298156738 accuracy=0.6066666666666667 val_acc=0.4057142857142857 test_acc=0.3342391304347826
...
Epoch 20 loss=1.3035770654678345 accuracy=0.7533333333333333 val_acc=0.5257142857142857 test_acc=0.396286231884058
Epoch 40 loss=0.5087469220161438 accuracy=1.0 val_acc=0.68 test_acc=0.5806159420289855
Epoch 60 loss=0.14569756388664246 accuracy=1.0 val_acc=0.7 test_acc=0.6109601449275363
Epoch 80 loss=0.054548606276512146 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6091485507246377
Epoch 100 loss=0.029325664043426514 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666
Epoch 120 loss=0.01924823224544525 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942
Epoch 140 loss=0.013934796676039696 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 160 loss=0.010667545720934868 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 180 loss=0.008476126939058304 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5991847826086957
...
Epoch 190 loss=0.007636724505573511 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 191 loss=0.0075600543059408665 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 192 loss=0.007484584581106901 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 193 loss=0.007410289254039526 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 194 loss=0.0073371464386582375 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 195 loss=0.007265167310833931 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 196 loss=0.007194266188889742 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 197 loss=0.007124484982341528 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 198 loss=0.007055748254060745 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 199 loss=0.006988039705902338 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
As we can see, after 200 iterations the GCN reaches an accuracy of about 70% on the validation set and about 60% on the test set.
# Visualize the training curves
fig, axes = plt.subplots(4, sharex=True, figsize=(12, 8))
fig.suptitle('Training Metrics')
axes[0].set_ylabel("Loss", fontsize=14)
axes[0].plot(train_loss_results)
axes[1].set_ylabel("Accuracy", fontsize=14)
axes[1].plot(train_accuracy_results)
axes[2].set_ylabel("Val Acc", fontsize=14)
axes[2].plot(train_val_results)
axes[3].set_ylabel("Test Acc", fontsize=14)
axes[3].plot(train_test_results)
plt.show()