PyG进阶
MassagePassing
引入
MassagePassing
是PyG(PyTorch
Geometric)库中的一个核心类,它为实现基于消息传递机制的图神经网络(GNN)提供了一个通用的框架。
消息传递机制是图神经网络的基础,其核心思想是通过节点间的消息传递和聚合来更新节点的特征表示。具体步骤如下:
- 消息生成(Message Generation):每个节点会根据自身的特征和其邻居节点的特征生成消息
- 消息聚合(Message Aggregation):每个节点会将其邻居节点传递过来的消息进行聚合,比如求和、求平均等
- 节点更新(Node Update):每个节点会根据聚合后的消息和自身的特征进行更新,得到新的节点特征
在PyG中使用MessagePassing
类实现自定义的图神经网络层,通常需要以下步骤:
- 继承
MessagePassing
类 - 初始化类,设置消息传递的方向和聚合方式
- 实现
message
方法,用于消息生成 - 实现
update
方法,用于更新节点特征
比如,使用MessagePassing
类实现GCN:
1 |
|
其中,add_self_loops()
函数的作用是给节点添加自环,未添加自环的edge_index
维度为[2, num_edges]
,添加自环后为[2, num_edges_new]
(即在其中添加了自己连向自己的边)。
row
和cow
的维度均为[num_edges_new]
,其中row
代表所有边中的源节点,cow
代表所有边中的目标节点
degree()
函数用于计算每个节点的度,返回的是[num_nodes]
维度的数据,第\(i\)个数据代表第\(i\)个节点的度,因此deg
的维度为[num_nodes]
deg_inv_sqrt
是deg
的负平方根,维度同样是[num_nodes]
norm
是归一化系数,对应的每一条边的归一化系数,GCN公式为:
\[
\mathbf{h}_i^{(l+1)} = \sigma\left(
\mathbf{h}_i^{(l)}\mathbf{W}_0^{(l)} +
\sum_{j\in\mathbf{\mathcal{N}(i)}}\frac{1}{\sqrt{\mathcal{N}(i)}\sqrt{\mathcal{N}(j)}}\mathbf{h}_j^{(l)}\mathbf{W}^{(l)}
\right)
\] 那么节点\(i\)在聚合节点\(j\)信息的时候使用的归一化系数就是 \[
\frac{1}{\sqrt{\mathcal{N}(i)}\sqrt{\mathcal{N}(j)}}
\] 而deg_inv_sqrt[row]
就是每个源节点\(i\)对应的\(\frac{1}{\sqrt{\mathcal{N}(i)}}\),同理,deg_inv_sqrt[cal]
就是每个目标节点\(j\)对应的\(\frac{1}{\sqrt{\mathcal{N}(j)}}\)
因此,norm
中存储的就是各个源节点和目标节点对应的归一化系数,维度为[num_edges_new]
self.propagate()
会把传入的参数传递给message
、aggregate
、update
方法,实现消息传递
self.propagate()
需要传入的参数:
必选参数:
edge_index
:维度[2, num_edges]
x
:节点的特征张量,形状为[num_nodes, feat_dim]
可选参数:
norm
:归一化系数,维度[num_edges]
自定义参数:比如节点的类别信息、边的权重等等:
1
self.propagate(edge_index, x=x, norm=norm, edge_weight=edge_weight)
message()
函数负责生成从邻居节点传递到中心节点的消息,利用边的属性、归一化系数等等,生成要传递的消息,这些消息后续会被聚合起来,用于更新中心节点的特征。
message()
函数需要传入的参数:
- 必选参数:
x_j
:每条边源节点的特征矩阵,维度为[num_edges, feat_dim]
- 可选参数:
x_i
:每条边目标节点的特征矩阵,维度为[num_edges, feat_dim]
- 其他自定义参数,比如:
norm
:归一化系数,维度为[num_edges]
edge_weight
:边权重
message()
函数的传入顺序:x_j
、x_i
放在第一位和第二位,如果不需要x_i
,可以省略,剩下的参数的传入顺序与self.propagate()
参数传入顺序相同。
比如:
1 |
|
或:
1 |
|
其中norm.view(-1, 1)
的维度为[num_edges, 1]
,用于归一化x_j
特征
update()
函数是对聚合后的消息进行处理,进而更新节点特征。
update()
函数需要传入的参数:
aggr_out
:是aggregate
函数聚合后的结果,aggr_out
的维度通常为[num_nodes, out_channes]
,num_nodes
是图中节点的数量,out_channels
是聚合后消息的特征维度。aggr_out
一般作为第一个参数- 自定义参数,顺序与
self.propagate()
传递顺序相同
比如:
1 |
|
也可以加入非线性变换等操作:
1 |
|
示例
自定义一个LightGCN模型:
1 |
|
常用方法
degree()
degree()
函数主要用于计算图中节点的度,用法:
1 |
|
index
:是一个一维的torch.Tensor
,其元素为边的目标节点索引。在有向图里,index
通常代表边的目标节点索引;在无向图中,index
可以代表任意一端节点的索引num_nodes
:可选参数,是一个整数,代表图中节点的总数,若不指定该参数,函数会将index
中最大的值加1作为节点的总数dtype
:同样是可选参数,指定返回张量的数据类型,若不指定,会使用index
的数据类型
返回值:一个一维的torch.Tensor
,其长度等于节点总数,每个元素代表对应节点的度
add_self_loops()
add_self_loops()
函数用于为图添加自环(self-loops),用法:
1 |
|
edge_index
:是一个维度为[2, num_edges]
的二维torch.Tensor
,用于表示图中边的连接关系。其中第一行存储源节点的索引,第二行存储目标节点的索引edge_weights
:可选参数,形状为[num_edges]
的一维torch.Tensor
,用于表示每条边的权重,如果提供了该参数,函数会在添加自环时相应地处理边的权重fill_value
:同样是可选参数,用于指定添加的自环的边的权重值,默认为1num_nodes
:可选参数,是一个整数,代表图中节点的总数,如果不指定该参数,函数会根据edge_index
中最大节点的索引加1来确定节点数
返回值:返回一个元组(edge_index_new, edge_weight_new)
:
edge_index_new
:是一个形状为[2, num_edges_new]
的二维torch.Tensor
,表示添加自环后的边索引,其中num_edges_new
是添加自环后边的总数edge_weight_new
:是一个形状为[num_edges_new]
的一维torch.Tensor
,表示添加自环后边的权重。如果edge_weight
为None
,则返回一个None