
A Hands-On Introduction to Graph Convolutional Networks

YoungTimes · Published 2022-04-28 · Column: 半杯茶的小酒杯

Multi-layer Graph Convolutional Network (GCN) with first-order filters. Source: http://tkipf.github.io/graph-convolutional-networks/

The complete code and data for this article have been uploaded to GitHub; feedback is very welcome! https://github.com/YoungTimes/GNN/tree/master/GCN

1. What is a GCN?

Graph Convolutional Networks (GCN), like CNNs, are feature extractors. CNNs are very powerful on regularly structured data such as images.

An image as a pixel matrix (Euclidean structure), image source [4]

In the real world, however, much data is irregularly structured; the typical case is graph-structured data such as social networks and knowledge graphs, and this is the kind of data GNNs excel at.

Social network topology (non-Euclidean structure), image source [4]

This article walks through a complete example of using a GCN to classify papers, showing how a GCN works and why.

We use the Cora dataset, which consists of the features and class labels of 2708 papers together with the 5429 citation edges between them. Each paper falls into one of 7 classes: Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning, Theory.

The end goal: given the features of a paper, predict which of these classes it belongs to.

2. The Cora Dataset

2.1 Download

https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

2.2 Data Contents

The Cora dataset is a dataset for classifying machine learning papers; it contains three files:

-- README: describes the dataset;

-- cora.cites: the citation graph between papers. Each line contains two paper IDs: the first is the ID of the cited paper, the second the ID of the citing paper. A sample is shown in the output further below.

-- cora.content: information on all 2708 papers, one per line, in the format <paper_id> <word_attributes>+ <class_label>. paper_id uniquely identifies the paper; word_attributes is a 1433-dimensional word vector, where each element corresponds to a vocabulary word, 0 meaning the word does not appear in the paper and 1 meaning it does; class_label is the paper's category, one of the 7 classes listed above.

Let's first look at what the data in the Cora dataset actually looks like...

import pandas as pd
import numpy as np

# Load the data: fields are tab-separated, with no header row
raw_data_content = pd.read_csv('data/cora/cora.content', sep='\t', header=None)

# [2708 * 1435]
(row, col) = raw_data_content.shape
print("Cora Contents’s Row: {}, Col: {}".format(row, col))

print("=============================================")

# Each row is a 1435-dimensional vector: the first column is the paper ID,
# the last column is the paper's label
raw_data_sample = raw_data_content.head(3)
features_sample = raw_data_sample.iloc[:, 1:-1]
labels_sample = raw_data_sample.iloc[:, -1]
labels_onehot_sample = pd.get_dummies(labels_sample)


print("features:{}".format(features_sample))

print("=============================================")

print("labels:{}".format(labels_sample))


print("=============================================")

print("labels one hot:{}".format(labels_onehot_sample))

Cora content's rows: 2708, cols: 1435
=============================================
features:   1     2     3     4     5     6     7     8     9     10    ...  1424  \
0     0     0     0     0     0     0     0     0     0     0  ...     0
1     0     0     0     0     0     0     0     0     0     0  ...     0
2     0     0     0     0     0     0     0     0     0     0  ...     0

   1425  1426  1427  1428  1429  1430  1431  1432  1433
0     0     0     1     0     0     0     0     0     0
1     0     1     0     0     0     0     0     0     0
2     0     0     0     0     0     0     0     0     0

[3 rows x 1433 columns]
=============================================
labels:0           Neural_Networks
1             Rule_Learning
2    Reinforcement_Learning
Name: 1434, dtype: object
=============================================
labels one hot:   Neural_Networks  Reinforcement_Learning  Rule_Learning
0                1                       0              0
1                0                       0              1
2                0                       1              0
raw_data_cites = pd.read_csv('data/cora/cora.cites', sep='\t', header=None)

# [5429 * 2]
(row, col) = raw_data_cites.shape

print("Cora cites' rows: {}, cols: {}".format(row, col))

print("=============================================")

raw_data_cites_sample = raw_data_cites.head(10)

print(raw_data_cites_sample)

print("=============================================")

import scipy.sparse as sp

# Map raw paper IDs to contiguous indices in [0, 2708)
idx = np.array(raw_data_content.iloc[:, 0], dtype=np.int32)
idx_map = {j: i for i, j in enumerate(idx)}

# Convert the citation pairs to an adjacency matrix over the 2708 nodes
edge_indexs = np.array(list(map(idx_map.get, raw_data_cites.values.flatten())), dtype=np.int32)
edge_indexs = edge_indexs.reshape(raw_data_cites.shape)

num_nodes = raw_data_content.shape[0]
adjacency = sp.coo_matrix((np.ones(len(edge_indexs)),
            (edge_indexs[:, 0], edge_indexs[:, 1])),
            shape=(num_nodes, num_nodes), dtype="float32")

print(adjacency)
Cora cites' rows: 5429, cols: 2
           0        1
0         35     1033
1         35   103482
2         35   103515
3         35  1050679
4         35  1103960
...      ...      ...
5424  853116    19621
5425  853116   853155
5426  853118  1140289
5427  853155   853118
5428  954315  1155073

[5429 rows x 2 columns]
=============================================
  (163, 402)	1.0
  (163, 659)	1.0
  (163, 1696)	1.0
  (163, 2295)	1.0
  (163, 1274)	1.0
  (163, 1286)	1.0
  (163, 1544)	1.0
  (163, 2600)	1.0
  (163, 2363)	1.0
  (163, 1905)	1.0
  (163, 1611)	1.0
  (163, 141)	1.0
  (163, 1807)	1.0
  (163, 1110)	1.0
  (163, 174)	1.0
  (163, 2521)	1.0
  (163, 1792)	1.0
  (163, 1675)	1.0
  (163, 1334)	1.0
  (163, 813)	1.0
  (163, 1799)	1.0
  (163, 1943)	1.0
  (163, 2077)	1.0
  (163, 765)	1.0
  (163, 769)	1.0
  :	:
  (2228, 1093)	1.0
  (2228, 1094)	1.0
  (2228, 2068)	1.0
  (2228, 2085)	1.0
  (2694, 2331)	1.0
  (617, 226)	1.0
  (422, 1691)	1.0
  (2142, 2096)	1.0
  (1477, 1252)	1.0
  (1485, 1252)	1.0
  (2185, 2109)	1.0
  (2117, 2639)	1.0
  (1211, 1247)	1.0
  (1884, 745)	1.0
  (1884, 1886)	1.0
  (1884, 1902)	1.0
  (1885, 745)	1.0
  (1885, 1884)	1.0
  (1885, 1886)	1.0
  (1885, 1902)	1.0
  (1886, 745)	1.0
  (1886, 1902)	1.0
  (1887, 2258)	1.0
  (1902, 1887)	1.0
  (837, 1686)	1.0

3. Building the Training, Test, and Validation Sets

Here, samples [0, 150) are used as the training set, [150, 500) as the validation set, and [500, 2708) as the test set. In the implementation, boolean masks (train_mask, val_mask, test_mask) distinguish the three splits.

train_index = np.arange(150)
val_index = np.arange(150, 500)
test_index = np.arange(500, 2708)

# One mask per split, sized by the number of nodes in the graph (2708)
num_nodes = raw_data_content.shape[0]
train_mask = np.zeros(num_nodes, dtype=bool)
val_mask = np.zeros(num_nodes, dtype=bool)
test_mask = np.zeros(num_nodes, dtype=bool)

train_mask[train_index] = True
val_mask[val_index] = True
test_mask[test_index] = True

4. The Core GCN Network Model

A graph is one of the most important concepts in data structures, and yes, the graph in a graph neural network is exactly the same graph as in data structures. Suppose the network's input graph contains N nodes, each with d features; the features of all nodes then form an N x d matrix X, and the adjacency relations between nodes form an N x N adjacency matrix A. Together, X and A constitute the input to the graph neural network.

4.1 Core Formula
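The layer-wise propagation rule of the GCN (Kipf & Welling, https://arxiv.org/abs/1609.02907) is:

H^{(l+1)} = \sigma \left( \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)} \right) \tag{1}

where \tilde{A} = A + I is the adjacency matrix with self-loops added, \tilde{D} is the degree matrix of \tilde{A}, H^{(l)} is the matrix of node features at layer l (with H^{(0)} = X), W^{(l)} is the trainable weight matrix of layer l, and \sigma is a non-linear activation such as ReLU.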

4.2 The Network Propagation Process

Take an undirected graph made up of 5 nodes as an example.

Its adjacency matrix A is:

A = \left[ \begin{matrix} 0 & 1 & 0 & 0 & 1\\ 1 & 0 & 1 & 1 & 0 \\ 0 & 1 & 0 & 1 & 0 \\ 0 & 1 & 1 & 0 & 1 \\ 1 & 0 & 0 & 1 & 0 \end{matrix} \right]

Adding self-loops to the adjacency matrix, i.e. adding the identity matrix I, gives:

\tilde{A} = \left[ \begin{matrix} 1 & 1 & 0 & 0 & 1\\ 1 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 0 \\ 0 & 1 & 1 & 1 & 1 \\ 1 & 0 & 0 & 1 & 1 \end{matrix} \right]

For simplicity, assume every node in the graph has a one-dimensional feature:

X = [1, 2, 3, 4, 5]^T

Then:

\tilde{A} * X = \left[ \begin{matrix} 1 * 1 + 1 * 2 + 0 * 3 + 0 * 4 + 1 * 5 \\ 1 * 1 + 1 * 2 + 1 * 3 + 1 * 4 + 0 * 5 \\ 0 * 1 + 1 * 2 + 1 * 3 + 1 * 4 + 0 * 5 \\ 0 * 1 + 1 * 2 + 1 * 3 + 1 * 4 + 1 * 5 \\ 1 * 1 + 0 * 2 + 0 * 3 + 1 * 4 + 1 * 5 \\ \end{matrix} \right]
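Evaluating the product gives [8, 10, 9, 14, 10]^T. A quick numpy check of this small example:

import numpy as np

A_tilde = np.array([[1, 1, 0, 0, 1],
                    [1, 1, 1, 1, 0],
                    [0, 1, 1, 1, 0],
                    [0, 1, 1, 1, 1],
                    [1, 0, 0, 1, 1]])
X = np.array([1, 2, 3, 4, 5])

print(A_tilde @ X)  # -> [ 8 10  9 14 10]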

What do we see in this expression? Every node has fused the information of its first-order (one-hop) neighbors into its own representation; this is the essence of formula (1). Propagation through the network is really just each node in the graph repeatedly aggregating information from its neighbors. Multiplying twice,

\tilde{A} * \tilde{A} * X

fuses the information of each node's second-order (two-hop) neighbors as well.

At this point it should be clear why the identity matrix is added to the adjacency matrix: the diagonal of the adjacency matrix is all zeros, so multiplying the feature matrix by the adjacency matrix alone would discard every node's own information. Adding the identity matrix preserves each node's own features.

The next point is justified only intuitively here (in fact, I have not yet worked through the detailed derivation; I will add it when time permits).

假设节点A的邻居只有B,为了对A进行信息聚合,最直接的方法就是平均贡献,即:New_A = 0.5* A + 0.5 * B,这样的做法看起来很合理,但是如果B的邻居非常多(极端情况是,它与图Graph上的所有其它节点都有连接),经过特征聚合之后,图(Graph)上许多的特征就会非常相似,因此需要考虑节点(Node)的度,不要让度过大的节点(Node)贡献过大。
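Damping high-degree nodes is exactly what the symmetric normalization \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} in formula (1) achieves. A minimal sketch of how it can be computed with scipy (the function name is illustrative):

import numpy as np
import scipy.sparse as sp

def normalize_adjacency(adjacency):
    # A~ = A + I: add self-loops so every node keeps its own features
    adjacency = sp.coo_matrix(adjacency) + sp.eye(adjacency.shape[0])
    # D~^{-1/2}: inverse square root of the degree matrix of A~
    degree = np.array(adjacency.sum(axis=1)).flatten()
    d_inv_sqrt = sp.diags(np.power(degree, -0.5))
    # D~^{-1/2} A~ D~^{-1/2}
    return d_inv_sqrt.dot(adjacency).dot(d_inv_sqrt).tocoo()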

The two-layer GCN is built from the graph convolution layer below:

from __future__ import print_function

import tensorflow as tf
from tensorflow.keras import backend as K

class GraphConvolution(tf.keras.layers.Layer):
    """Basic graph convolution layer as in https://arxiv.org/abs/1609.02907"""
    def __init__(self, units, support=1,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                ):
        
        super(GraphConvolution, self).__init__()
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        
        self.supports_masking = True

        self.support = support
        assert support >= 1

    def build(self, input_shapes):
        features_shape = input_shapes[0]
        assert len(features_shape) == 2
        input_dim = features_shape[1]

        self.kernel = self.add_weight(shape = (input_dim * self.support, self.units),
                                      initializer = self.kernel_initializer,
                                      name = 'kernel',
                                      regularizer = self.kernel_regularizer)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer)
        else:
            self.bias = None
            
        self.built = True

    def call(self, inputs, mask=None):
        features = inputs[0]
        basis = inputs[1:]

        supports = list()
        for i in range(self.support):
            supports.append(K.dot(basis[i], features))
        supports = K.concatenate(supports, axis=1)
        output = K.dot(supports, self.kernel)

        if self.use_bias:
            output += self.bias
            
        return self.activation(output)
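The full two-layer model is imported later as GraphConvolutionModel from graph.py but is not listed in the post; a minimal sketch consistent with the layer above could look like this (the hidden width of 16 is an assumption, following the original GCN paper):

class GraphConvolutionModel(tf.keras.Model):
    def __init__(self, hidden_units=16, num_classes=7):
        super(GraphConvolutionModel, self).__init__()
        # Layer 1: input features -> hidden representation, with ReLU
        self.gcn1 = GraphConvolution(hidden_units, activation='relu')
        # Layer 2: hidden representation -> 7 class logits (softmax is applied in the loss)
        self.gcn2 = GraphConvolution(num_classes)

    def call(self, inputs, training=None):
        features, adjacency = inputs
        h = self.gcn1([features, adjacency])
        return self.gcn2([h, adjacency])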

5. Training the Network

5.1 Preparing the Training Data

The data-processing details were all covered above; one extra point worth noting is that each node's features also need to be normalized during preprocessing.
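A minimal sketch of that row normalization, assuming features is the [2708 x 1433] binary word matrix from Section 2 as a numpy array:

import numpy as np

# Divide each paper's word vector by its row sum so that papers with many
# words and papers with few words end up on the same scale; np.maximum
# guards against division by zero for all-zero rows.
row_sum = features.sum(axis=1, keepdims=True)
features = features / np.maximum(row_sum, 1)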

from graph import GraphConvolutionLayer, GraphConvolutionModel
from dataset import CoraData

import time
import tensorflow as tf
import matplotlib.pyplot as plt

dataset = CoraData()
features, labels, adj, train_mask, val_mask, test_mask = dataset.data()

graph = [features, adj]
Process data ...
Loading cora dataset...
Dataset has 2708 nodes, 2708 edges, 1433 features.

5.2 Computing the Loss

In the loss function, the loss is computed only over the training data (nodes where train_mask is True).

loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

def loss(model, x, y, train_mask, training):

    y_ = model(x, training=training)

    # Only nodes where train_mask is True contribute to the loss
    masked_logits = tf.gather_nd(y_, tf.where(train_mask))
    masked_labels = tf.gather_nd(y, tf.where(train_mask))

    return loss_object(y_true=masked_labels, y_pred=masked_logits)


def grad(model, inputs, targets, train_mask):
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets, train_mask, training=True)
    
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

5.3 The Actual Training Process

def test(mask):
    logits = model(graph)

    masked_logits = tf.gather_nd(logits, tf.where(mask))
    masked_labels = tf.gather_nd(labels, tf.where(mask))

    ll = tf.math.equal(tf.math.argmax(masked_labels, -1), tf.math.argmax(masked_logits, -1))
    accuracy = tf.reduce_mean(tf.cast(ll, dtype=tf.float64))

    return accuracy

model = GraphConvolutionModel()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, decay=5e-5)

# Record metrics during training so they can be visualized at the end
train_loss_results = []
train_accuracy_results = []
train_val_results = []
train_test_results = []

num_epochs = 200

for epoch in range(num_epochs):

    loss_value, grads = grad(model, graph, labels, train_mask)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    accuracy = test(train_mask)
    val_acc = test(val_mask)
    test_acc = test(test_mask)

    train_loss_results.append(loss_value)
    train_accuracy_results.append(accuracy)
    train_val_results.append(val_acc)
    train_test_results.append(test_acc)

    print("Epoch {} loss={} accuracy={} val_acc={} test_acc={}".format(epoch, loss_value, accuracy, val_acc, test_acc))
Epoch 0 loss=1.9472886323928833 accuracy=0.4066666666666667 val_acc=0.3142857142857143 test_acc=0.2817028985507246
Epoch 1 loss=1.9314587116241455 accuracy=0.4866666666666667 val_acc=0.3742857142857143 test_acc=0.33106884057971014
Epoch 2 loss=1.9133251905441284 accuracy=0.5 val_acc=0.38571428571428573 test_acc=0.34782608695652173
Epoch 3 loss=1.8908278942108154 accuracy=0.5266666666666666 val_acc=0.3942857142857143 test_acc=0.3496376811594203
Epoch 4 loss=1.8662141561508179 accuracy=0.5533333333333333 val_acc=0.3942857142857143 test_acc=0.3423913043478261
Epoch 5 loss=1.8400791883468628 accuracy=0.56 val_acc=0.38285714285714284 test_acc=0.3401268115942029
Epoch 6 loss=1.8119205236434937 accuracy=0.5866666666666667 val_acc=0.38571428571428573 test_acc=0.33605072463768115
Epoch 7 loss=1.78205144405365 accuracy=0.6066666666666667 val_acc=0.37714285714285717 test_acc=0.33016304347826086
Epoch 8 loss=1.751450777053833 accuracy=0.6066666666666667 val_acc=0.38857142857142857 test_acc=0.33016304347826086
Epoch 9 loss=1.7200360298156738 accuracy=0.6066666666666667 val_acc=0.4057142857142857 test_acc=0.3342391304347826
Epoch 10 loss=1.6870578527450562 accuracy=0.6333333333333333 val_acc=0.42 test_acc=0.3428442028985507
Epoch 11 loss=1.6523456573486328 accuracy=0.64 val_acc=0.43142857142857144 test_acc=0.34646739130434784
Epoch 12 loss=1.616371512413025 accuracy=0.6333333333333333 val_acc=0.44 test_acc=0.35190217391304346
Epoch 13 loss=1.579743504524231 accuracy=0.64 val_acc=0.44857142857142857 test_acc=0.360054347826087
Epoch 14 loss=1.5426799058914185 accuracy=0.64 val_acc=0.4542857142857143 test_acc=0.36594202898550726
Epoch 15 loss=1.5049867630004883 accuracy=0.6466666666666666 val_acc=0.46285714285714286 test_acc=0.3686594202898551
Epoch 16 loss=1.466316819190979 accuracy=0.6666666666666666 val_acc=0.46285714285714286 test_acc=0.37273550724637683
Epoch 17 loss=1.4266818761825562 accuracy=0.6733333333333333 val_acc=0.4857142857142857 test_acc=0.37726449275362317
Epoch 18 loss=1.3862168788909912 accuracy=0.6866666666666666 val_acc=0.5 test_acc=0.3808876811594203
Epoch 19 loss=1.3451327085494995 accuracy=0.7266666666666667 val_acc=0.5114285714285715 test_acc=0.39221014492753625
Epoch 20 loss=1.3035770654678345 accuracy=0.7533333333333333 val_acc=0.5257142857142857 test_acc=0.396286231884058
Epoch 21 loss=1.2615602016448975 accuracy=0.7866666666666666 val_acc=0.5342857142857143 test_acc=0.40806159420289856
Epoch 22 loss=1.2191429138183594 accuracy=0.8 val_acc=0.5457142857142857 test_acc=0.40851449275362317
Epoch 23 loss=1.1763759851455688 accuracy=0.82 val_acc=0.56 test_acc=0.41893115942028986
Epoch 24 loss=1.133314609527588 accuracy=0.82 val_acc=0.5742857142857143 test_acc=0.4316123188405797
Epoch 25 loss=1.09010648727417 accuracy=0.8666666666666667 val_acc=0.5771428571428572 test_acc=0.4429347826086957
Epoch 26 loss=1.0468487739562988 accuracy=0.8666666666666667 val_acc=0.5942857142857143 test_acc=0.452445652173913
Epoch 27 loss=1.0036686658859253 accuracy=0.9 val_acc=0.6028571428571429 test_acc=0.46195652173913043
Epoch 28 loss=0.9607122540473938 accuracy=0.9466666666666667 val_acc=0.6171428571428571 test_acc=0.47101449275362317
Epoch 29 loss=0.9181469678878784 accuracy=0.9666666666666667 val_acc=0.6257142857142857 test_acc=0.47690217391304346
Epoch 30 loss=0.8761056661605835 accuracy=0.9666666666666667 val_acc=0.6371428571428571 test_acc=0.48777173913043476
Epoch 31 loss=0.8347143530845642 accuracy=0.9666666666666667 val_acc=0.6485714285714286 test_acc=0.4986413043478261
Epoch 32 loss=0.79410320520401 accuracy=0.9666666666666667 val_acc=0.6514285714285715 test_acc=0.5140398550724637
Epoch 33 loss=0.7544015645980835 accuracy=0.9666666666666667 val_acc=0.6514285714285715 test_acc=0.5253623188405797
Epoch 34 loss=0.7157045602798462 accuracy=0.9666666666666667 val_acc=0.6628571428571428 test_acc=0.5335144927536232
Epoch 35 loss=0.6780853271484375 accuracy=0.9666666666666667 val_acc=0.6657142857142857 test_acc=0.5443840579710145
Epoch 36 loss=0.6416434049606323 accuracy=0.98 val_acc=0.6771428571428572 test_acc=0.5498188405797102
Epoch 37 loss=0.6064579486846924 accuracy=0.9866666666666667 val_acc=0.6685714285714286 test_acc=0.5588768115942029
Epoch 38 loss=0.5725471377372742 accuracy=1.0 val_acc=0.6771428571428572 test_acc=0.5706521739130435
Epoch 39 loss=0.5399670600891113 accuracy=1.0 val_acc=0.6771428571428572 test_acc=0.5765398550724637
Epoch 40 loss=0.5087469220161438 accuracy=1.0 val_acc=0.68 test_acc=0.5806159420289855
Epoch 41 loss=0.4788845181465149 accuracy=1.0 val_acc=0.68 test_acc=0.5869565217391305
Epoch 42 loss=0.4503992795944214 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.5892210144927537
Epoch 43 loss=0.4233132302761078 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5932971014492754
Epoch 44 loss=0.3976200222969055 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5960144927536232
Epoch 45 loss=0.3733169138431549 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5973731884057971
Epoch 46 loss=0.3503628969192505 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6000905797101449
Epoch 47 loss=0.32872718572616577 accuracy=1.0 val_acc=0.6942857142857143 test_acc=0.6014492753623188
Epoch 48 loss=0.3083726167678833 accuracy=1.0 val_acc=0.7 test_acc=0.6023550724637681
Epoch 49 loss=0.28924670815467834 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6032608695652174
Epoch 50 loss=0.27130240201950073 accuracy=1.0 val_acc=0.7057142857142857 test_acc=0.6023550724637681
Epoch 51 loss=0.2544911503791809 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6032608695652174
Epoch 52 loss=0.23875992000102997 accuracy=1.0 val_acc=0.7057142857142857 test_acc=0.6041666666666666
Epoch 53 loss=0.22406704723834991 accuracy=1.0 val_acc=0.7085714285714285 test_acc=0.6059782608695652
Epoch 54 loss=0.21036122739315033 accuracy=1.0 val_acc=0.7085714285714285 test_acc=0.6077898550724637
Epoch 55 loss=0.1975751519203186 accuracy=1.0 val_acc=0.7057142857142857 test_acc=0.6073369565217391
Epoch 56 loss=0.18565499782562256 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6064311594202898
Epoch 57 loss=0.17455774545669556 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6073369565217391
Epoch 58 loss=0.16423100233078003 accuracy=1.0 val_acc=0.7028571428571428 test_acc=0.6082427536231884
Epoch 59 loss=0.15462662279605865 accuracy=1.0 val_acc=0.7 test_acc=0.6096014492753623
Epoch 60 loss=0.14569756388664246 accuracy=1.0 val_acc=0.7 test_acc=0.6109601449275363
Epoch 61 loss=0.13739608228206635 accuracy=1.0 val_acc=0.7 test_acc=0.6109601449275363
Epoch 62 loss=0.12968382239341736 accuracy=1.0 val_acc=0.6942857142857143 test_acc=0.6127717391304348
Epoch 63 loss=0.12251587212085724 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6127717391304348
Epoch 64 loss=0.11585451662540436 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6118659420289855
Epoch 65 loss=0.10966223478317261 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6114130434782609
Epoch 66 loss=0.10390333086252213 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6114130434782609
Epoch 67 loss=0.09854617714881897 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6109601449275363
Epoch 68 loss=0.09356522560119629 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6105072463768116
Epoch 69 loss=0.08893042057752609 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6100543478260869
Epoch 70 loss=0.08461488038301468 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6100543478260869
Epoch 71 loss=0.08059524744749069 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6105072463768116
Epoch 72 loss=0.07684874534606934 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6100543478260869
Epoch 73 loss=0.0733552798628807 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6096014492753623
Epoch 74 loss=0.07009640336036682 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6100543478260869
Epoch 75 loss=0.06705603748559952 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6096014492753623
Epoch 76 loss=0.06421645730733871 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6091485507246377
Epoch 77 loss=0.06155867129564285 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6086956521739131
Epoch 78 loss=0.05906983092427254 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6086956521739131
Epoch 79 loss=0.05673719570040703 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6086956521739131
Epoch 80 loss=0.054548606276512146 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6091485507246377
Epoch 81 loss=0.052494876086711884 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6091485507246377
Epoch 82 loss=0.050564948469400406 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6082427536231884
Epoch 83 loss=0.04874930530786514 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6082427536231884
Epoch 84 loss=0.04704001545906067 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6086956521739131
Epoch 85 loss=0.04542906954884529 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6082427536231884
Epoch 86 loss=0.04390912503004074 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6077898550724637
Epoch 87 loss=0.04247550293803215 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6082427536231884
Epoch 88 loss=0.04112052172422409 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6073369565217391
Epoch 89 loss=0.03983796760439873 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6073369565217391
Epoch 90 loss=0.03862294182181358 accuracy=1.0 val_acc=0.6828571428571428 test_acc=0.6073369565217391
Epoch 91 loss=0.03747102618217468 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6064311594202898
Epoch 92 loss=0.03637789562344551 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6059782608695652
Epoch 93 loss=0.035339970141649246 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.605072463768116
Epoch 94 loss=0.03435278683900833 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.605072463768116
Epoch 95 loss=0.03341297432780266 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666
Epoch 96 loss=0.03251757472753525 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.603713768115942
Epoch 97 loss=0.03166373074054718 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6032608695652174
Epoch 98 loss=0.03084862045943737 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666
Epoch 99 loss=0.030070148408412933 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6046195652173914
Epoch 100 loss=0.029325664043426514 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6041666666666666
Epoch 101 loss=0.028613392263650894 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6046195652173914
Epoch 102 loss=0.02793121710419655 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666
Epoch 103 loss=0.027277110144495964 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666
Epoch 104 loss=0.02665024995803833 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666
Epoch 105 loss=0.026048408821225166 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 106 loss=0.02547014318406582 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 107 loss=0.024914324283599854 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 108 loss=0.02437940239906311 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 109 loss=0.023864200338721275 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666
Epoch 110 loss=0.023367829620838165 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 111 loss=0.02288944460451603 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 112 loss=0.022427884861826897 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.603713768115942
Epoch 113 loss=0.021982161328196526 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666
Epoch 114 loss=0.021551571786403656 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6041666666666666
Epoch 115 loss=0.0211354810744524 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6032608695652174
Epoch 116 loss=0.020733091980218887 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6028079710144928
Epoch 117 loss=0.020343618467450142 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6019021739130435
Epoch 118 loss=0.019966619089245796 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6014492753623188
Epoch 119 loss=0.019601713865995407 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6014492753623188
Epoch 120 loss=0.01924823224544525 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942
Epoch 121 loss=0.01890559121966362 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942
Epoch 122 loss=0.018573053181171417 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942
Epoch 123 loss=0.018250416964292526 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 124 loss=0.017937207594513893 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6005434782608695
Epoch 125 loss=0.017632879316806793 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6005434782608695
Epoch 126 loss=0.017337223514914513 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6000905797101449
Epoch 127 loss=0.017049791291356087 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203
Epoch 128 loss=0.01677037589251995 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203
Epoch 129 loss=0.016498537734150887 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203
Epoch 130 loss=0.016234181821346283 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203
Epoch 131 loss=0.01597682572901249 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.5996376811594203
Epoch 132 loss=0.015726083889603615 accuracy=1.0 val_acc=0.6857142857142857 test_acc=0.6000905797101449
Epoch 133 loss=0.015481753274798393 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 134 loss=0.01524385903030634 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 135 loss=0.015011751092970371 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6009963768115942
Epoch 136 loss=0.014785613864660263 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 137 loss=0.014565042220056057 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 138 loss=0.014349889941513538 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 139 loss=0.014139854349195957 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 140 loss=0.013934796676039696 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 141 loss=0.013734581880271435 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 142 loss=0.013539088889956474 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 143 loss=0.013348042033612728 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 144 loss=0.01316142175346613 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203
Epoch 145 loss=0.012978975661098957 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 146 loss=0.01280051190406084 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6000905797101449
Epoch 147 loss=0.012626114301383495 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.6000905797101449
Epoch 148 loss=0.012455460615456104 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 149 loss=0.012288510799407959 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 150 loss=0.012125165201723576 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 151 loss=0.011965337209403515 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 152 loss=0.011808915995061398 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 153 loss=0.011655798181891441 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 154 loss=0.011505785398185253 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 155 loss=0.011358906514942646 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 156 loss=0.011214920319616795 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 157 loss=0.011073877103626728 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 158 loss=0.010935710743069649 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 159 loss=0.010800261981785297 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 160 loss=0.010667545720934868 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 161 loss=0.01053738035261631 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 162 loss=0.010409791953861713 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 163 loss=0.010284650139510632 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 164 loss=0.010161920450627804 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 165 loss=0.010041462257504463 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 166 loss=0.009923247620463371 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 167 loss=0.009807263500988483 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 168 loss=0.009693419560790062 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 169 loss=0.009581669233739376 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 170 loss=0.009471924044191837 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6005434782608695
Epoch 171 loss=0.009364166297018528 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 172 loss=0.00925836805254221 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 173 loss=0.009154461324214935 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 174 loss=0.009052390232682228 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.6000905797101449
Epoch 175 loss=0.008952111005783081 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203
Epoch 176 loss=0.008853581734001637 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203
Epoch 177 loss=0.008756755851209164 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203
Epoch 178 loss=0.008661641739308834 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5996376811594203
Epoch 179 loss=0.008568093180656433 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5991847826086957
Epoch 180 loss=0.008476126939058304 accuracy=1.0 val_acc=0.6885714285714286 test_acc=0.5991847826086957
Epoch 181 loss=0.008385734632611275 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5991847826086957
Epoch 182 loss=0.008296912536025047 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5991847826086957
Epoch 183 loss=0.008209548890590668 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 184 loss=0.008123602718114853 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 185 loss=0.008039114996790886 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 186 loss=0.007956001907587051 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 187 loss=0.00787423737347126 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 188 loss=0.0077937874011695385 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 189 loss=0.00771462032571435 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 190 loss=0.007636724505573511 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 191 loss=0.0075600543059408665 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 192 loss=0.007484584581106901 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 193 loss=0.007410289254039526 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 194 loss=0.0073371464386582375 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 195 loss=0.007265167310833931 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 196 loss=0.007194266188889742 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 197 loss=0.007124484982341528 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 198 loss=0.007055748254060745 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203
Epoch 199 loss=0.006988039705902338 accuracy=1.0 val_acc=0.6914285714285714 test_acc=0.5996376811594203

As the logs show, after 200 epochs the GCN reaches roughly 70% accuracy on the validation set and roughly 60% on the test set. Note that training accuracy hits 1.0 as early as epoch 38 while validation accuracy plateaus around 0.69, a sign that the 150-sample training set is quickly overfit.

5.4 Visualizing the Training Process

# Visualize the training process
fig, axes = plt.subplots(4, sharex=True, figsize=(12, 8))
fig.suptitle('Training Metrics')

axes[0].set_ylabel("Loss", fontsize=14)
axes[0].plot(train_loss_results)

axes[1].set_ylabel("Accuracy", fontsize=14)
axes[1].plot(train_accuracy_results)

axes[2].set_ylabel("Val Acc", fontsize=14)
axes[2].plot(train_val_results)

axes[3].set_ylabel("Test Acc", fontsize=14)
axes[3].plot(train_test_results)

plt.show()

Originally published 2020-06-13 on the WeChat public account 半杯茶的小酒杯.
