GCN 从理论到源码分析

GCN 从理论到源码分析

GNN 简介

在 CNN 中,模型的输入数据一般都是固定宽高的数据,比如图片,这类数据在形状上是规则的,称为欧几里得结构化数‘据。而许多重要的现实世界数据集以图表或网络的形式出现:社会网络、知识图表、蛋白质相互作用网络、万维网等,这种数据是不规则的,没有固定形状,称为非欧几里得结构化数据。虽然 CNN 取得了巨大的成功,然而却不能直接把 CNN 应用到这种数据上。近年来一些研究致力于把 RNN 和 CNN 推广到非欧几里得结构化数据,出现了很多新的算法。包括谱卷积和空间卷积,其中Defferrard et al. (NIPS 2016) 基于类神经网络模型学习的自由参数Chebyshev多项式的谱域近似平滑滤波器在 MNIST 数据集上取得了很好的效果。而基于谱卷积 (spectral graph convolutions),衍生出了Kipf & Welling (ICLR 2017),不仅运算更加快速,而且在一些基准数据集上的准确率也更高。

GCN 参数定义

大多数 GNN (graph neural network) 都有一个通用的体系结构,我们把这些称为 GCN (graph convolutional network),因为参数在整个图或者部分图上都是共享的,因此称为卷积 (convolutional)。这些模型的目标是学习一个函数,可以对图上的特征进行建模,图定义为 \(\mathcal{G}=(\mathcal{V}, \mathcal{E})\),其中 \(\mathcal{V}\) 表示节点, \(\mathcal{E}\) 表示边。输入包括两部分

  • 特征矩阵 \(X\),形状是 \(N \times D\)\(N\) 表示节点的数量,\(D\) 表示每个节点特征的数量。每个节点的特征向量表示为 \(x_{i}\)
  • 表示图结构的矩阵,一般是图的邻接矩阵 \(A\)

输出有两种形式,一种是节点级别的输出特征矩阵 \(Z\),形状是 \(N \times F\)\(N\) 表示节点的数量,\(F\) 表示输出的每个节点特征的数量;另一种是图级别的输出标量 \(z\),可以通过引入某种形式的池化操作来得到。因此每一层 GCN 可以表示如下:

\[ \begin{align}X^{(l+1)}=f\left(X^{(l)}, A\right) \end{align}\]

其中 \(X^{(l)}\) 表示第 \(l\) 层输入的特征矩阵,\(X^{(l+1)}\) 表示第 \(l\) 层输出的特征矩阵,\(f(\cdot, \cdot)\) 表示一种映射或者变换方法。不同的 \(f(\cdot, \cdot)\) 形成不同的模型。我们主要研究的就是 \(f(\cdot, \cdot)\) 的选择和参数。

GCN 的一个例子

我们把上面一般形式的 GCN 表达式具体化 :

\[f\left(X^{(l)}, A\right)=\sigma\left(A X^{(l)} W^{(l)}\right)\]

其中 \(W^{(l)}\) 是第 \(l\) 层的权重矩阵,\(\sigma(\cdot)\) 是某种非线性激活函数,如 \(ReLU\),虽然这个表达式很简单,但是却很强大。但首先让我们来看下这个公式存在的 2 个问题和处理方法。

  • 第一个问题是邻接矩阵 \(A\) 的对角线上的元素全为 0,表示没有自己到自己的边。所以邻接矩阵 \(A\) 与特征矩阵 \(X^{(l)}\) 相乘表示对于每个节点,把与该节点相邻节点的特征向量相加,除了该节点自身的特征向量。所以需要加上自己到自己的边,称为 self-loops,使得邻接矩阵对角线上的元素全为 1。要添加 self-loops,只需要把邻接矩阵 \(A\) 加上一个单位矩阵即可。
  • 第二个问题是邻接矩阵 \(A\) 没有标准化,因此 \(A X^{(l)} W^{(l)}\) 的结果会改变特征矩阵 \(X\) 的元素取值的数量级 (尺度),我们以可通过观察矩阵 \(A\) 的特征值来理解这一点 (特征值在某种意义上可以反应这个矩阵的大小)。将矩阵 \(A\) 标准化之后,使得每行的和为 1 就可以了,如 \(D^{-1} A\),其中 \(D\) 是对角节点度矩阵 (diagonal node degree matrix)。 \(D^{-1}A X^{(l)}\) 相当于计算相邻节点特征的平均值。实际上,当我们使用对称规范化时,特征的变换变得更有趣,如 \(D^{-\frac{1}{2}} A D^{-\frac{1}{2}}\),这种变化不再相当于计算相邻节点特征的平均值。

经过上面的分析,最终的前向传播表达式为:

\[f\left(X^{(l)}, A\right)=\sigma\left(\hat{D}^{-\frac{1}{2}} \hat{A} \hat{D}^{-\frac{1}{2}} X^{(l)} W^{(l)}\right)\]

其中 \(\hat{A}=A+I\)\(I\) 是单位矩阵。\(\hat{D}\) 是矩阵 \(\hat{A}\) 的对角节点度矩阵 (diagonal node degree matrix)。

对 karate club network 进行 Embedding

karate club network 是一个简单的图数据集,如下图,不同的颜色表示通过聚类获得的 communities。


这里我们使用 3 层的 GCN,权重矩阵 \(W\) 是随机初始化的,特征矩阵 \(X=I\),表示所有节点没有特征。3 层 GCN 在前向传播的过程中有效地卷积每个节点的 3 阶邻域。值得注意的是,该模型生成了这些节点的 Embedding,这些节点与图的 Community 结构非常相似(参见下图)。而我们的权重是随机初始化得到的,并且还没有执行反向传播更新参数,也就是说我们还没有学习模型,就已经得到了不错的结果。


这个结果有些令人惊讶。最近一篇关于关于 DeepWalk 模型的论文 (Perozzi et al., KDD 2014) 表明,他们可以在一个复杂的无监督训练过程中学习到非常相似的 Embedding。为什么使用我们简单的没有训练过的 GCN 模型,就可以得到这样一个不错的结果?这看起来像是有点违背天下没有免费的午餐的定理。

我们可以通过著名的 Weisfeiler-Lehman 算法来阐明这一点,将 GCN 模型解释为 Weisfeiler-Lehman 算法的一个广义的、可微的版本。一维 Weisfeiler-Lehman 算法的工作原理如下:

对于图上的所有点 \(v_{i} \in \mathcal{G}\)

  • 获取邻接顶点 \(\left\{v_{j}\right\}\) 的特征集合 \(\left\{x_{v_{j}}\right\}\)
  • 更新顶点 \(i\) 的特征 \(x_{v_{i}} \leftarrow \operatorname{hash}\left(\sum_{j} x_{v_{j}}\right)\),其中 \(\operatorname{hash}(\cdot)\) 理想情况下是内射散列函数

重复 \(k\) 次上述步骤或者直到收敛。

实验发现,Weisfeiler-Lehman 算法为大多数图生成了一组独特的特征矩阵,这意味着为图中每个节点生成了一个独特的特征向量,每个节点的向量相当于描述其在图中的功能。不过在高度规则的图没什么效果,如网格、链等。对于大多数不规则图,通过 Weisfeiler-Lehman 算法得到的这个特征可用于检查图是否同构(即两个图是否相同,即使两个图的节点排列不同也能检测出)。

回到 GCN 上面来,上面的一层 GCN 的表达式是矩阵形式,而向量形式如下:

\[x_{v_{i}}^{(l+1)}=\sigma\left(\sum_{j} \frac{1}{c_{i j}} x_{v_{j}}^{(l)} W^{(l)}\right)\]

其中 \(j\) 表示 \(v_{j}\) 的邻接节点,\(c_{i,j}\) 是边 \(\left(v_{i}, v_{j}\right)\) 的标准化常数,来源于经过了对称标准化的邻接矩阵 \(D^{-\frac{1}{2}} A D^{-\frac{1}{2}}\)。上式的向量化前向传播规则,可以解释为原始 Weisfeiler-Lehman 算法中使用的散列函数的可微和参数化,并且初始化随机权重矩阵,使其正交 (权重正交的意义可查看 Glorot & Bengio, AISTATS 2010)。我们观察到了很好的结果,得到了有意义的平滑 Embedding,并且可以使用 Embedding 之间的距离来衡量和解释局部图结构的相似性。

半监督学习

由于模型中的所有参数都是可微的和参数化的,我们可以添加一些标签、训练模型并观察 Embedding 的变化。我们可以使用 Kipf & Welling (ICLR 2017) 中介绍的 GCN 半监督学习算法。在下图中,我们为每个节点标记一个类或者 Community,然后开始迭代训练 300 次,得到了节点特征在 2 维空间的可视化过程。


模型直接产生了一个二维的隐空间,我们将其进行了可视化。我们观察到,3 层 GCN 模型能够线性地分离 Community,每个类只有一个标注的节点。这是一个相当好的结果,因为模型的输入没有包括所有节点的特征。

Torch geometric GCNConv 源码分析

上文中,GCN 的表达式是从整个图的角度来考虑和描述的。从单个节点来说,每个节点的特征向量可以表示为的变换 (前向传播) 的向量形式可以表示为:

\(\mathbf{x}_{i}^{(k)}=\sum_{j \in \mathcal{N}(i) \cup\{i\}} \frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}} \cdot\left(\boldsymbol{W} \cdot \mathbf{x}_{j}^{(k-1)}\right)\)

其中 \(W\) 是权重矩阵 (即模型学习过程中要更新的参数),\(\mathbf{x}_{i}^{(k)}\) 表示节点 \(i\) 在第 \(k\) 次迭代的特征向量,\(deg(i)\) 表示节点 \(i\) 的度,\(\mathcal{N}(i)\) 表示节点 \(i\) 所有邻接节点的集合。

GCNConv源码

目前最新版本的源码已经和下面的代码不一样了,但是原理基本上是一样的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation.
self.lin = torch.nn.Linear(in_channels, out_channels)

def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]

# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

# Step 2: Linearly transform node feature matrix.
x = self.lin(x)

# Step 3-5: Start propagating messages.
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

def message(self, x_j, edge_index, size):
# x_j has shape [E, out_channels]
# edge_index has shape [2, E]

# Step 3: Normalize node features.
row, col = edge_index
deg = degree(row, size[0], dtype=x_j.dtype) # [N, ]
deg_inv_sqrt = deg.pow(-0.5) # [N, ]
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

return norm.view(-1, 1) * x_j

def update(self, aggr_out):
# aggr_out has shape [N, out_channels]

# Step 5: Return new node embeddings.
return aggr_out

初始化init

1
2
3
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation.
self.lin = torch.nn.Linear(in_channels, out_channels)

定义self.lin为线性变换函数,其中in_channel是每个节点输入特征的维度,out_channels是每个节点输出特征的维度。这里只是定义了结构,具体的特征维度变换逻辑是在forward()里实现的。这一部分对应着\(XW\)。输入的特征矩阵维度是(N, in_channels),输出的特征矩阵维度是(N, out_channels),其中 N 是节点个数。

forward

1
2
3
4
5
6
7
8
9
10
11
12
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]

# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

# Step 2: Linearly transform node feature matrix.
x = self.lin(x)

# Step 3-5: Start propagating messages.
return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

第 1 步 edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))是给邻接矩阵加上 self loops,也即构造出矩阵 \(\hat{A}=A+I\),在torch geometric中,邻接矩阵表示为 2 维数组,第 1 行表示边的起始节点 (source 节点),第 2 行表示边的目标节点 (target 节点)。如下图


邻接矩阵 \(A\) 可以表示为:

1
2
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]],

那么 \(\hat{A}\) 可以表示为:

1
2
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
[1, 0, 2, 1, 0, 1, 2]],

第 2 步 x = self.lin(x)是使用在__init__()中定义的线性变换函数进行变换。

最后调用self.propagate(),在self.propagate()中会调用self.message()函数。

message

1
2
3
4
5
6
7
8
9
10
11
def message(self, x_j, edge_index, size):
# x_j has shape [E, out_channels]
# edge_index has shape [2, E]

# Step 3: Normalize node features.
row, col = edge_index
deg = degree(row, size[0], dtype=x_j.dtype) # [N, ]
deg_inv_sqrt = deg.pow(-0.5) # [N, ]
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

return norm.view(-1, 1) * x_j

message()函数入参x_j 的形状是 [E, out_channels],其中 E 表示边的数量。由上面可知,特征矩阵经过线性变换后的输出形状是 (N, out_channels),边的矩阵的形状为 [2, E]。row, col = edge_index表示取出所有边的起始节点和目标节点,row表示边的起始节点的结合,col表示边的目标节点的集合。在无向图中,这两者是等价的。以target节点作为索引,从线性变换后的特征矩阵中索引得到target节点的特征矩阵x_j,示意图如下:

所以x_j 的形状是 [E, out_channels]。

deg = degree(row, size[0], dtype=x_j.dtype)是计算每个节点的度,deg的形状为 [E, ]。

deg_inv_sqrt = deg.pow(-0.5)是把每个节点的度开根号。

norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]是把每一条边的source节点的度的开根号,和target节点的度的开根号,相乘,得到每个节点的标准化系数,对应于 \(\frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}}\)。函数最后返回的是每一条边的标准化系数 × 这条边target节点特征,形状是 [E, out_channels]。

最后调用self.aggregate()对邻居节点特征进行聚合操作。

把每个节点以及邻接节点的特征向量进行聚合,也就是按照source节点进行聚合,聚合操作有 sum (相加)、mean(取均值)、max(取最大值)。

这里有3条边的source都是节点0,因此将这三行向量聚合(这里用相加),最终得到一个形状为(N, out_channels)的特征矩阵,就是这一层GCN的输出。

参考

评论