GCN 从理论到源码分析
GNN 简介
在 CNN 中,模型的输入数据一般都是固定宽高的数据,比如图片,这类数据在形状上是规则的,称为欧几里得结构化数‘据。而许多重要的现实世界数据集以图表或网络的形式出现:社会网络、知识图表、蛋白质相互作用网络、万维网等,这种数据是不规则的,没有固定形状,称为非欧几里得结构化数据。虽然 CNN 取得了巨大的成功,然而却不能直接把 CNN 应用到这种数据上。近年来一些研究致力于把 RNN 和 CNN 推广到非欧几里得结构化数据,出现了很多新的算法。包括谱卷积和空间卷积,其中Defferrard et al. (NIPS 2016) 基于类神经网络模型学习的自由参数Chebyshev多项式的谱域近似平滑滤波器在 MNIST 数据集上取得了很好的效果。而基于谱卷积 (spectral graph convolutions),衍生出了Kipf & Welling (ICLR 2017),不仅运算更加快速,而且在一些基准数据集上的准确率也更高。
GCN 参数定义
大多数 GNN (graph neural network) 都有一个通用的体系结构,我们把这些称为 GCN (graph convolutional network),因为参数在整个图或者部分图上都是共享的,因此称为卷积 (convolutional)。这些模型的目标是学习一个函数,可以对图上的特征进行建模,图定义为 \(\mathcal{G}=(\mathcal{V}, \mathcal{E})\),其中 \(\mathcal{V}\) 表示节点, \(\mathcal{E}\) 表示边。输入包括两部分
- 特征矩阵 \(X\),形状是 \(N \times D\),\(N\) 表示节点的数量,\(D\) 表示每个节点特征的数量。每个节点的特征向量表示为 \(x_{i}\)。
- 表示图结构的矩阵,一般是图的邻接矩阵 \(A\)。
输出有两种形式,一种是节点级别的输出特征矩阵 \(Z\),形状是 \(N \times F\),\(N\) 表示节点的数量,\(F\) 表示输出的每个节点特征的数量;另一种是图级别的输出标量 \(z\),可以通过引入某种形式的池化操作来得到。因此每一层 GCN 可以表示如下:
\[ \begin{align}X^{(l+1)}=f\left(X^{(l)}, A\right) \end{align}\]
其中 \(X^{(l)}\) 表示第 \(l\) 层输入的特征矩阵,\(X^{(l+1)}\) 表示第 \(l\) 层输出的特征矩阵,\(f(\cdot, \cdot)\) 表示一种映射或者变换方法。不同的 \(f(\cdot, \cdot)\) 形成不同的模型。我们主要研究的就是 \(f(\cdot, \cdot)\) 的选择和参数。
GCN 的一个例子
我们把上面一般形式的 GCN 表达式具体化 :
\[f\left(X^{(l)}, A\right)=\sigma\left(A X^{(l)} W^{(l)}\right)\]
其中 \(W^{(l)}\) 是第 \(l\) 层的权重矩阵,\(\sigma(\cdot)\) 是某种非线性激活函数,如 \(ReLU\),虽然这个表达式很简单,但是却很强大。但首先让我们来看下这个公式存在的 2 个问题和处理方法。
- 第一个问题是邻接矩阵 \(A\) 的对角线上的元素全为 0,表示没有自己到自己的边。所以邻接矩阵 \(A\) 与特征矩阵 \(X^{(l)}\) 相乘表示对于每个节点,把与该节点相邻节点的特征向量相加,除了该节点自身的特征向量。所以需要加上自己到自己的边,称为 self-loops,使得邻接矩阵对角线上的元素全为 1。要添加 self-loops,只需要把邻接矩阵 \(A\) 加上一个单位矩阵即可。
- 第二个问题是邻接矩阵 \(A\) 没有标准化,因此 \(A X^{(l)} W^{(l)}\) 的结果会改变特征矩阵 \(X\) 的元素取值的数量级 (尺度),我们以可通过观察矩阵 \(A\) 的特征值来理解这一点 (特征值在某种意义上可以反应这个矩阵的大小)。将矩阵 \(A\) 标准化之后,使得每行的和为 1 就可以了,如 \(D^{-1} A\),其中 \(D\) 是对角节点度矩阵 (diagonal node degree matrix)。 \(D^{-1}A X^{(l)}\) 相当于计算相邻节点特征的平均值。实际上,当我们使用对称规范化时,特征的变换变得更有趣,如 \(D^{-\frac{1}{2}} A D^{-\frac{1}{2}}\),这种变化不再相当于计算相邻节点特征的平均值。
经过上面的分析,最终的前向传播表达式为:
\[f\left(X^{(l)}, A\right)=\sigma\left(\hat{D}^{-\frac{1}{2}} \hat{A} \hat{D}^{-\frac{1}{2}} X^{(l)} W^{(l)}\right)\]
其中 \(\hat{A}=A+I\),\(I\) 是单位矩阵。\(\hat{D}\) 是矩阵 \(\hat{A}\) 的对角节点度矩阵 (diagonal node degree matrix)。
对 karate club network 进行 Embedding
karate club network 是一个简单的图数据集,如下图,不同的颜色表示通过聚类获得的 communities。

这里我们使用 3 层的 GCN,权重矩阵 \(W\) 是随机初始化的,特征矩阵 \(X=I\),表示所有节点没有特征。3 层 GCN 在前向传播的过程中有效地卷积每个节点的 3 阶邻域。值得注意的是,该模型生成了这些节点的 Embedding,这些节点与图的 Community 结构非常相似(参见下图)。而我们的权重是随机初始化得到的,并且还没有执行反向传播更新参数,也就是说我们还没有学习模型,就已经得到了不错的结果。

这个结果有些令人惊讶。最近一篇关于关于 DeepWalk 模型的论文 (Perozzi et al., KDD 2014) 表明,他们可以在一个复杂的无监督训练过程中学习到非常相似的 Embedding。为什么使用我们简单的没有训练过的 GCN 模型,就可以得到这样一个不错的结果?这看起来像是有点违背天下没有免费的午餐
的定理。
我们可以通过著名的 Weisfeiler-Lehman 算法来阐明这一点,将 GCN 模型解释为 Weisfeiler-Lehman 算法的一个广义的、可微的版本。一维 Weisfeiler-Lehman 算法的工作原理如下:
对于图上的所有点 \(v_{i} \in \mathcal{G}\):
- 获取邻接顶点 \(\left\{v_{j}\right\}\) 的特征集合 \(\left\{x_{v_{j}}\right\}\)
- 更新顶点 \(i\) 的特征 \(x_{v_{i}} \leftarrow \operatorname{hash}\left(\sum_{j} x_{v_{j}}\right)\),其中 \(\operatorname{hash}(\cdot)\) 理想情况下是内射散列函数
重复 \(k\) 次上述步骤或者直到收敛。
实验发现,Weisfeiler-Lehman 算法为大多数图生成了一组独特的特征矩阵,这意味着为图中每个节点生成了一个独特的特征向量,每个节点的向量相当于描述其在图中的功能。不过在高度规则的图没什么效果,如网格、链等。对于大多数不规则图,通过 Weisfeiler-Lehman 算法得到的这个特征可用于检查图是否同构(即两个图是否相同,即使两个图的节点排列不同也能检测出)。
回到 GCN 上面来,上面的一层 GCN 的表达式是矩阵形式,而向量形式如下:
\[x_{v_{i}}^{(l+1)}=\sigma\left(\sum_{j} \frac{1}{c_{i j}} x_{v_{j}}^{(l)} W^{(l)}\right)\]
其中 \(j\) 表示 \(v_{j}\) 的邻接节点,\(c_{i,j}\) 是边 \(\left(v_{i}, v_{j}\right)\) 的标准化常数,来源于经过了对称标准化的邻接矩阵 \(D^{-\frac{1}{2}} A D^{-\frac{1}{2}}\)。上式的向量化前向传播规则,可以解释为原始 Weisfeiler-Lehman 算法中使用的散列函数的可微和参数化,并且初始化随机权重矩阵,使其正交 (权重正交的意义可查看 Glorot & Bengio, AISTATS 2010)。我们观察到了很好的结果,得到了有意义的平滑 Embedding,并且可以使用 Embedding 之间的距离来衡量和解释局部图结构的相似性。
半监督学习
由于模型中的所有参数都是可微的和参数化的,我们可以添加一些标签、训练模型并观察 Embedding 的变化。我们可以使用 Kipf & Welling (ICLR 2017) 中介绍的 GCN 半监督学习算法。在下图中,我们为每个节点标记一个类或者 Community,然后开始迭代训练 300 次,得到了节点特征在 2 维空间的可视化过程。

模型直接产生了一个二维的隐空间,我们将其进行了可视化。我们观察到,3 层 GCN 模型能够线性地分离 Community,每个类只有一个标注的节点。这是一个相当好的结果,因为模型的输入没有包括所有节点的特征。
Torch geometric GCNConv 源码分析
上文中,GCN 的表达式是从整个图的角度来考虑和描述的。从单个节点来说,每个节点的特征向量可以表示为的变换 (前向传播) 的向量形式可以表示为:
\(\mathbf{x}_{i}^{(k)}=\sum_{j \in \mathcal{N}(i) \cup\{i\}} \frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}} \cdot\left(\boldsymbol{W} \cdot \mathbf{x}_{j}^{(k-1)}\right)\)
其中 \(W\) 是权重矩阵 (即模型学习过程中要更新的参数),\(\mathbf{x}_{i}^{(k)}\) 表示节点 \(i\) 在第 \(k\) 次迭代的特征向量,\(deg(i)\) 表示节点 \(i\) 的度,\(\mathcal{N}(i)\) 表示节点 \(i\) 所有邻接节点的集合。
GCNConv源码
目前最新版本的源码已经和下面的代码不一样了,但是原理基本上是一样的。
1 | import torch |
初始化init
1 | def __init__(self, in_channels, out_channels): |
定义self.lin
为线性变换函数,其中in_channel
是每个节点输入特征的维度,out_channels是每个节点输出特征的维度。这里只是定义了结构,具体的特征维度变换逻辑是在forward()
里实现的。这一部分对应着\(XW\)。输入的特征矩阵维度是(N, in_channels)
,输出的特征矩阵维度是(N, out_channels)
,其中 N 是节点个数。
forward
1 | def forward(self, x, edge_index): |
第 1 步 edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
是给邻接矩阵加上 self loops,也即构造出矩阵 \(\hat{A}=A+I\),在torch geometric
中,邻接矩阵表示为 2 维数组,第 1 行表示边的起始节点 (source 节点),第 2 行表示边的目标节点 (target 节点)。如下图

邻接矩阵 \(A\) 可以表示为:
1 | edge_index = torch.tensor([[0, 1, 1, 2], |
那么 \(\hat{A}\) 可以表示为:
1 | edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2], |
第 2 步 x = self.lin(x)
是使用在__init__()
中定义的线性变换函数进行变换。
最后调用self.propagate()
,在self.propagate()
中会调用self.message()
函数。
message
1 | def message(self, x_j, edge_index, size): |
message()
函数入参x_j
的形状是 [E, out_channels],其中 E 表示边的数量。由上面可知,特征矩阵经过线性变换后的输出形状是 (N, out_channels),边的矩阵的形状为 [2, E]。row, col = edge_index
表示取出所有边的起始节点和目标节点,row
表示边的起始节点的结合,col
表示边的目标节点的集合。在无向图中,这两者是等价的。以target
节点作为索引,从线性变换后的特征矩阵中索引得到target
节点的特征矩阵x_j
,示意图如下:
所以x_j
的形状是 [E, out_channels]。
deg = degree(row, size[0], dtype=x_j.dtype)
是计算每个节点的度,deg
的形状为 [E, ]。
deg_inv_sqrt = deg.pow(-0.5)
是把每个节点的度开根号。
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
是把每一条边的source
节点的度的开根号,和target
节点的度的开根号,相乘,得到每个节点的标准化系数,对应于 \(\frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}}\)。函数最后返回的是每一条边的标准化系数 × 这条边target
节点特征,形状是 [E, out_channels]。
最后调用self.aggregate()
对邻居节点特征进行聚合操作。
把每个节点以及邻接节点的特征向量进行聚合,也就是按照source
节点进行聚合,聚合操作有 sum (相加)、mean(取均值)、max(取最大值)。
这里有3条边的source
都是节点0,因此将这三行向量聚合(这里用相加),最终得到一个形状为(N, out_channels)
的特征矩阵,就是这一层GCN的输出。
参考