PyG进阶
MassagePassing
引入
MassagePassing是PyG(PyTorch
Geometric)库中的一个核心类,它为实现基于消息传递机制的图神经网络(GNN)提供了一个通用的框架。
消息传递机制是图神经网络的基础,其核心思想是通过节点间的消息传递和聚合来更新节点的特征表示。具体步骤如下:
- 消息生成(Message Generation):每个节点会根据自身的特征和其邻居节点的特征生成消息
- 消息聚合(Message Aggregation):每个节点会将其邻居节点传递过来的消息进行聚合,比如求和、求平均等
- 节点更新(Node Update):每个节点会根据聚合后的消息和自身的特征进行更新,得到新的节点特征
在PyG中使用MessagePassing类实现自定义的图神经网络层,通常需要以下步骤:
- 继承
MessagePassing类 - 初始化类,设置消息传递的方向和聚合方式
- 实现
message方法,用于消息生成 - 实现
update方法,用于更新节点特征
比如,使用MessagePassing类实现GCN:
1 | |
其中,add_self_loops()函数的作用是给节点添加自环,未添加自环的edge_index维度为[2, num_edges],添加自环后为[2, num_edges_new](即在其中添加了自己连向自己的边)。
row和cow的维度均为[num_edges_new],其中row代表所有边中的源节点,cow代表所有边中的目标节点
degree()函数用于计算每个节点的度,返回的是[num_nodes]维度的数据,第\(i\)个数据代表第\(i\)个节点的度,因此deg的维度为[num_nodes]
deg_inv_sqrt是deg的负平方根,维度同样是[num_nodes]
norm是归一化系数,对应的每一条边的归一化系数,GCN公式为:
\[
\mathbf{h}_i^{(l+1)} = \sigma\left(
\mathbf{h}_i^{(l)}\mathbf{W}_0^{(l)} +
\sum_{j\in\mathbf{\mathcal{N}(i)}}\frac{1}{\sqrt{\mathcal{N}(i)}\sqrt{\mathcal{N}(j)}}\mathbf{h}_j^{(l)}\mathbf{W}^{(l)}
\right)
\] 那么节点\(i\)在聚合节点\(j\)信息的时候使用的归一化系数就是 \[
\frac{1}{\sqrt{\mathcal{N}(i)}\sqrt{\mathcal{N}(j)}}
\] 而deg_inv_sqrt[row]就是每个源节点\(i\)对应的\(\frac{1}{\sqrt{\mathcal{N}(i)}}\),同理,deg_inv_sqrt[cal]就是每个目标节点\(j\)对应的\(\frac{1}{\sqrt{\mathcal{N}(j)}}\)
因此,norm中存储的就是各个源节点和目标节点对应的归一化系数,维度为[num_edges_new]
self.propagate()会把传入的参数传递给message、aggregate、update方法,实现消息传递
self.propagate()需要传入的参数:
必选参数:
edge_index:维度[2, num_edges]x:节点的特征张量,形状为[num_nodes, feat_dim]
可选参数:
norm:归一化系数,维度[num_edges]自定义参数:比如节点的类别信息、边的权重等等:
1
self.propagate(edge_index, x=x, norm=norm, edge_weight=edge_weight)
message()函数负责生成从邻居节点传递到中心节点的消息,利用边的属性、归一化系数等等,生成要传递的消息,这些消息后续会被聚合起来,用于更新中心节点的特征。
message()函数需要传入的参数:
- 必选参数:
x_j:每条边源节点的特征矩阵,维度为[num_edges, feat_dim]
- 可选参数:
x_i:每条边目标节点的特征矩阵,维度为[num_edges, feat_dim]- 其他自定义参数,比如:
norm:归一化系数,维度为[num_edges]edge_weight:边权重
message()函数的传入顺序:x_j、x_i放在第一位和第二位,如果不需要x_i,可以省略,剩下的参数的传入顺序与self.propagate()参数传入顺序相同。
比如:
1 | |
或:
1 | |
其中norm.view(-1, 1)的维度为[num_edges, 1],用于归一化x_j特征
update()函数是对聚合后的消息进行处理,进而更新节点特征。
update()函数需要传入的参数:
aggr_out:是aggregate函数聚合后的结果,aggr_out的维度通常为[num_nodes, out_channes],num_nodes是图中节点的数量,out_channels是聚合后消息的特征维度。aggr_out一般作为第一个参数- 自定义参数,顺序与
self.propagate()传递顺序相同
比如:
1 | |
也可以加入非线性变换等操作:
1 | |
示例
自定义一个LightGCN模型:
1 | |
常用方法
degree()
degree()函数主要用于计算图中节点的度,用法:
1 | |
index:是一个一维的torch.Tensor,其元素为边的目标节点索引。在有向图里,index通常代表边的目标节点索引;在无向图中,index可以代表任意一端节点的索引num_nodes:可选参数,是一个整数,代表图中节点的总数,若不指定该参数,函数会将index中最大的值加1作为节点的总数dtype:同样是可选参数,指定返回张量的数据类型,若不指定,会使用index的数据类型
返回值:一个一维的torch.Tensor,其长度等于节点总数,每个元素代表对应节点的度
add_self_loops()
add_self_loops()函数用于为图添加自环(self-loops),用法:
1 | |
edge_index:是一个维度为[2, num_edges]的二维torch.Tensor,用于表示图中边的连接关系。其中第一行存储源节点的索引,第二行存储目标节点的索引edge_weights:可选参数,形状为[num_edges]的一维torch.Tensor,用于表示每条边的权重,如果提供了该参数,函数会在添加自环时相应地处理边的权重fill_value:同样是可选参数,用于指定添加的自环的边的权重值,默认为1num_nodes:可选参数,是一个整数,代表图中节点的总数,如果不指定该参数,函数会根据edge_index中最大节点的索引加1来确定节点数
返回值:返回一个元组(edge_index_new, edge_weight_new):
edge_index_new:是一个形状为[2, num_edges_new]的二维torch.Tensor,表示添加自环后的边索引,其中num_edges_new是添加自环后边的总数edge_weight_new:是一个形状为[num_edges_new]的一维torch.Tensor,表示添加自环后边的权重。如果edge_weight为None,则返回一个None