前言

去年上 COMP4222 学习图同构网络(Graph Isomorphism Network)的笔记。
主要参考了 CS224W 以及原论文How Powerful are Graph Neural Networks?
WL-Test 威力强大不可不品尝(应该是)

TLDR

GIN 是一种能够达到与 Weisfeiler-Lehman(WL)测试等价表达能力的图神经网络。
GIN 通过精心设计的聚合函数,解决了图神经网络在表达能力上的瓶颈,使其能够区分大多数非同构图结构。
GIN 的核心是将节点自身特征与邻居特征进行”注入”式聚合,整个模型可表示为一系列 GIN 层的堆叠。

介绍

GIN 这份工作力求探究 GNN 理论上的表达能力。这里“表达能力”的定义类似于光学里的“分辨率”,也就是模型区分不同输入数据的能力边界。一个模型的表达能力越高,就越能捕获到图的细微结构差异,进而学习出一个更有区分度的 representation/feature/embedding. 对于两个图 $G_1, G_2$,假设模型把它们映射到两个不同的向量上我们就说模型区分了这两个图。

Weisfeiler-Lehman 算法

一种启发式算法,简称 WL-test,用来判断两个图之间是否存在同构。1-WL test 又称 Color Refinement 算法,表现形式就是不断迭代地给一个图里面的每个顶点上色,在某个时刻得到一个 stable colouring,稳定的上色。

1-WL test

对于一个图输入 $G(V,E)$:

  • 初始化: 为每一个顶点赋上相同的颜色 $\lambda_0$ 这个“颜色”可以是标签(label),度数,feature等表示;
  • 迭代:对每一个顶点 $v$ 的下一个颜色,我们参考这样一个等式:$$ \lambda_{n+1}(v) = (\lambda_{n}(v), { { \lambda_{n}(u) | u \in N(v)}})$$ 也就是一个颜色对(pair), 左边的元素是上一个颜色,右边的元素是所有邻居颜色的 multiset(多重集)。 根据多重集的定义我们可以严谨地判断两个节点是否会在下一次迭代中变成相同的颜色。

两个节点$u,v$ 的 $\lambda_{n+1}$ 是不同的,当且仅当:

  1. 它俩之前的颜色不同 $ \lambda_n(u) \neq \lambda_n(v) $;
  2. 存在颜色 $c$ 使得$u,v$ 有不同数量的$c$色邻居;

实际操作中我们往往会使用某些函数处理等号右边这一大坨信息,把它们映射到颜色空间内。

$$ \lambda_{n+1}(v) = \phi(\lambda_{n}(v), f({ { \lambda_{n}(u) | u \in N(v)}}))$$
$f$ 把多重集映射成某种表示(比如一个颜色), $\phi$ 是一个单射(injective) 把这个颜色对映射到一个颜色上。

怎么用来判断图同构?

对于两个图,只要我们同时分别对它们进行1-WL算法上色,直到两位各自稳定下来,把图里全部的点打包成一个多重集,比较个颜色的频次就可以判断两张图是否同构辣。

选取这个颜色的关键是:为每个节点生成一个唯一的、结构相关的表示,以反映其局部邻居结构。这是什么?这就是Representaion Learning!

WL test这种从邻居获取聚合信息的迭代算法可太符合GNN的世界观了,所以大伙纷纷表示这就是Message-Passing范式,并打算将其和NN进行有机的结合。

注意到我们在每一轮对节点更新的时候其实就是在接受一阶邻居传递过来的消息,于是我们用嵌入向量充当“颜色”,使用某些算子从旧标签和邻居多重集中产生新颜色。这下看懂了,这就是MPNN。

$$
h_v^{(k)} = COMBINE^{(k)}(h_v^{(k-1)}, AGGREGATE^{(k)}({h_u^{(k-1)} : u ∈ N(v)}))$$

$AGG$ 聚合函数的选择可以是MAX(MIN), SUM, MEAN等,这些排列组合就是我们之前见过的GNN, MPNN, GraphSAGE 等各种知名模型。
原文在这里采用了我们介绍Color refinement 时候的写法,写成$\phi , f$

$$
h_v^{(k)} = \phi^{(k)}(h_v^{(k-1)}, f^{(k)}({h_u^{(k-1)} : u ∈ N(v)}))
$$

GIN 结构

作者证明了一种简化上面Message Passing的方法,使得这个新结构能够学习(模仿)到理想中的$\phi , f$
更新后的节点表示如下:
$$
h_v^{(k)} = MLP^{(k)}( (1 + \epsilon^{(k)})\cdot h_v^{(k-1)}, \sum_{u \in N(v)}h_u^{(k-1)})
$$
这个 $\epsilon$ 可以是一个可学习参数,也可以是一个固定的标量值,使得模型能够调整中心节点与邻居节点的重要性权重。这样一来我们只要对邻居使用 $SUM$ 聚合。剩下的事情?MLP会出手。

采用梯度下降法学习$\epsilon$ 的GIN叫 GIN-$\epsilon$。如果是固定不变的值,比如0,就叫GIN-0。作者试验发现这两种变种在各种指标上各有千秋,但 GIN-$\epsilon$ “slightly but consistently outperforming GIN-0 in terms of test accuracy”。一个有趣的发现!这是为什么呢。

表达能力比拼

GIN最大的贡献在于理论上证明了:

  1. 只有求和聚合函数能够区分不同的多重集合
  2. GIN的表达能力与WL测试等价,能区分更多的正则图结构
  3. 其他GNN变体(如GCN、GraphSAGE)的表达能力严格弱于GIN

所以CS224W认为GIN是理论上的表达能力最强网络。

作者比较了之前的几个著名网络 GCN, GraphSAGE等,

然后放了一些数学证明,这里没有给出,因为当时期末不考(

  • GCN、GraphSAGE:使用平均或最大池化聚合,表达能力严格弱于1-WL测试
  • GIN:使用求和聚合,表达能力等价于1-WL测试,是消息传递框架下理论上最强大的模型
  • 高阶GNN:如PPGN、k-GNN等,能够达到k-WL测试的表达能力

提升表达能力的方法有很多:采样高阶邻居,子图采样,把各种先进时髦的结构,比如注意力机制等狠狠地融合…

图级表示学习

GIN最常用的领域是学习一整张图的表示,它的做法是:

  1. 通过多层GIN计算各层节点表示
  2. 对每一层的节点表示进行READOUT(如求和或最大池化)得到图级表示
  3. 将各层图表示拼接或求和,得到最终图表示
  4. 将最终图表示输入分类器进行任务预测

$$
h_G = \text{CONCAT}( \text{READOUT}( {h_v^{(k)} | v \in G}) | k = 0,1,…,K)
$$

实验数据集

原文主要在一下几个领域做的实验(AI总结的):
分子图分类(MUTAG, PTC, NCI1等)
社交网络分类(IMDB-BINARY, REDDIT-BINARY等)
蛋白质结构分析(PROTEINS)

LLM 写的实现示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GINConv, global_add_pool

class GIN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super(GIN, self).__init__()

# 初始化GIN层
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()

# 第一层
nn = Sequential(
Linear(in_channels, hidden_channels),
ReLU(),
Linear(hidden_channels, hidden_channels)
)
self.convs.append(GINConv(nn, train_eps=True))
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))

# 中间层
for _ in range(num_layers - 2):
nn = Sequential(
Linear(hidden_channels, hidden_channels),
ReLU(),
Linear(hidden_channels, hidden_channels)
)
self.convs.append(GINConv(nn, train_eps=True))
self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))

# 最终的图级预测层
self.lin = Linear(num_layers * hidden_channels, out_channels)

def forward(self, x, edge_index, batch):
# 保存每层的图表示
hidden_states = []

for conv, batch_norm in zip(self.convs, self.batch_norms):
x = conv(x, edge_index)
x = batch_norm(x)
x = torch.relu(x)
# 池化得到图表示
hidden_states.append(global_add_pool(x, batch))

# 拼接所有层的图表示
h_graph = torch.cat(hidden_states, dim=1)

# 图级预测
out = self.lin(h_graph)

return out