PyG进阶

MassagePassing

引入

MassagePassing是PyG(PyTorch Geometric)库中的一个核心类,它为实现基于消息传递机制的图神经网络(GNN)提供了一个通用的框架。

消息传递机制是图神经网络的基础,其核心思想是通过节点间的消息传递和聚合来更新节点的特征表示。具体步骤如下:

  1. 消息生成(Message Generation):每个节点会根据自身的特征和其邻居节点的特征生成消息
  2. 消息聚合(Message Aggregation):每个节点会将其邻居节点传递过来的消息进行聚合,比如求和、求平均等
  3. 节点更新(Node Update):每个节点会根据聚合后的消息和自身的特征进行更新,得到新的节点特征

在PyG中使用MessagePassing类实现自定义的图神经网络层,通常需要以下步骤:

  1. 继承MessagePassing
  2. 初始化类,设置消息传递的方向和聚合方式
  3. 实现message方法,用于消息生成
  4. 实现update方法,用于更新节点特征

比如,使用MessagePassing类实现GCN:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # 使用求和聚合方式
self.lin = nn.Linear(in_channels, out_channels)

def forward(self, x, edge_index):
# 添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

# 线性变换
x = self.lin(x) # XW

# 计算归一化系数
row, col = edge_index # 提取源节点row和目标节点col
deg = degree(col, x.size(0), dtype=x.dtype) # 计算度矩阵,但是存储形式是每个节点的度,维度为[num_nodes]
deg_inv_sqrt = deg.pow(-0.5) # D^(-1/2),维度同样是[num_nodes]
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # 归一化 D^(-1/2)AD^(-1/2)
# 上面的norm是针对每条边而言的,对应的是每条边的归一化

# 开始消息传递
return self.propagate(edge_index, x=x, norm=norm)

def message(self, x_j, norm):
# 生成消息
return norm.view(-1, 1) * x_j

def update(self, aggr_out):
# 更新节点特征
return aggr_out

其中,add_self_loops()函数的作用是给节点添加自环,未添加自环的edge_index维度为[2, num_edges],添加自环后为[2, num_edges_new](即在其中添加了自己连向自己的边)。

rowcow的维度均为[num_edges_new],其中row代表所有边中的源节点,cow代表所有边中的目标节点

degree()函数用于计算每个节点的度,返回的是[num_nodes]维度的数据,第\(i\)个数据代表第\(i\)个节点的度,因此deg的维度为[num_nodes]

deg_inv_sqrtdeg的负平方根,维度同样是[num_nodes]

norm是归一化系数,对应的每一条边的归一化系数,GCN公式为: \[ \mathbf{h}_i^{(l+1)} = \sigma\left( \mathbf{h}_i^{(l)}\mathbf{W}_0^{(l)} + \sum_{j\in\mathbf{\mathcal{N}(i)}}\frac{1}{\sqrt{\mathcal{N}(i)}\sqrt{\mathcal{N}(j)}}\mathbf{h}_j^{(l)}\mathbf{W}^{(l)} \right) \] 那么节点\(i\)在聚合节点\(j\)信息的时候使用的归一化系数就是 \[ \frac{1}{\sqrt{\mathcal{N}(i)}\sqrt{\mathcal{N}(j)}} \]deg_inv_sqrt[row]就是每个源节点\(i\)对应的\(\frac{1}{\sqrt{\mathcal{N}(i)}}\),同理,deg_inv_sqrt[cal]就是每个目标节点\(j\)对应的\(\frac{1}{\sqrt{\mathcal{N}(j)}}\)

因此,norm中存储的就是各个源节点和目标节点对应的归一化系数,维度为[num_edges_new]

self.propagate()会把传入的参数传递给messageaggregateupdate方法,实现消息传递

self.propagate()需要传入的参数:

  • 必选参数:

    • edge_index:维度[2, num_edges]

    • x:节点的特征张量,形状为[num_nodes, feat_dim]

  • 可选参数:

    • norm:归一化系数,维度[num_edges]

    • 自定义参数:比如节点的类别信息、边的权重等等:

      1
      self.propagate(edge_index, x=x, norm=norm, edge_weight=edge_weight)

message()函数负责生成从邻居节点传递到中心节点的消息,利用边的属性、归一化系数等等,生成要传递的消息,这些消息后续会被聚合起来,用于更新中心节点的特征。

message()函数需要传入的参数:

  • 必选参数:
    • x_j:每条边源节点的特征矩阵,维度为[num_edges, feat_dim]
  • 可选参数:
    • x_i:每条边目标节点的特征矩阵,维度为[num_edges, feat_dim]
    • 其他自定义参数,比如:
      • norm:归一化系数,维度为[num_edges]
      • edge_weight:边权重

message()函数的传入顺序:x_jx_i放在第一位和第二位,如果不需要x_i,可以省略,剩下的参数的传入顺序与self.propagate()参数传入顺序相同。

比如:

1
2
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j

或:

1
2
def message(self, x_j, x_i, norm):
return norm.view(-1, 1) * (x_j + x_i)

其中norm.view(-1, 1)的维度为[num_edges, 1],用于归一化x_j特征

update()函数是对聚合后的消息进行处理,进而更新节点特征

update()函数需要传入的参数

  • aggr_out:是aggregate函数聚合后的结果,aggr_out的维度通常为[num_nodes, out_channes]num_nodes是图中节点的数量,out_channels是聚合后消息的特征维度。aggr_out一般作为第一个参数
  • 自定义参数,顺序与self.propagate()传递顺序相同

比如:

1
2
3
def update(self, aggr_out):
return aggr_out
# 如果是上面这种代码,可以省略update函数不写,因为这就是默认的update方法

也可以加入非线性变换等操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class ComplexGCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add')
self.lin1 = nn.Linear(in_channels, out_channels)
self.lin2 = nn.Linear(out_channels, out_channels)
self.relu = nn.ReLU()

def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin1(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm)

def message(self, x_j, norm):
return norm.view(-1, 1) * x_j

def update(self, aggr_out):
# 应用非线性变换
out = self.linear2(combined)
out = self.relu(out)
return out

示例

自定义一个LightGCN模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
from torch_gepmetric.nn import MessagePassing
from torch_geometric.utils import degree

class LightGCNConv(MessagePassing):
def __init__(self):
super().__init__(aggr='add')

def forward(self, x, edge_index):
# 计算归一化系数
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt = deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

return self.propagate(edge_index, x=x, norm=norm)

def message(self, x_j, norm):
return norm.view(-1, 1) * x_j

def update(self, aggr_out):
return aggr_out


class LightGCNRecommender(nn.Module):
def __init__(self, num_users, num_items, hidden_dim, dropout=0.1):
super().__init__()

# 用户与商品的嵌入层
self.user_emb = nn.Embedding(num_users, hidden_dim)
self.item_emb = nn.Embedding(num_items, hidden_dim)

# 定义3层LightGCN
self.conv1 = LightGCNConv()
self.conv2 = LightGCNConv()
self.conv3 = LightGCNConv()

# 初始化参数
self._init_weights()

def _init_weights(self):
nn.init.normal_(self.user_emb.weight, std=0.1)
nn.init.normal_(self.item_emb.weight, std=0.1)

def forward(self, data):
# 处理输入特征
user_x = self.user_emb(data['user'].x.squeeze())
item_x = self.item_emb(data['item'].x.squeeze())

# 拼接
x = torch.cat([user_x, item_x], dim=0)

# 三层传播
x0 = x
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)

# 平均各层嵌入
x = (x0 + x1 + x2 + x3) / 4
return x

def predict(self, user_idx, item_idx, data):
x = self.forward(data)
user_emb = x[user_idx]
item_emb = x[item_idx]
return (user_emb * item_emb).sum(dim=-1)

常用方法

degree()

degree()函数主要用于计算图中节点的度,用法:

1
torch_geometric.utils.degree(index, num_nodes=None, dtype=None)
  • index:是一个一维的torch.Tensor,其元素为边的目标节点索引。在有向图里,index通常代表边的目标节点索引;在无向图中,index可以代表任意一端节点的索引
  • num_nodes:可选参数,是一个整数,代表图中节点的总数,若不指定该参数,函数会将index中最大的值加1作为节点的总数
  • dtype:同样是可选参数,指定返回张量的数据类型,若不指定,会使用index的数据类型

返回值:一个一维的torch.Tensor,其长度等于节点总数,每个元素代表对应节点的度

add_self_loops()

add_self_loops()函数用于为图添加自环(self-loops),用法:

1
torch_geometric.utils.add_self_loops(edge_index, edge_weights=None, fill_value=1, num_nodes=None)
  • edge_index:是一个维度为[2, num_edges]的二维torch.Tensor,用于表示图中边的连接关系。其中第一行存储源节点的索引,第二行存储目标节点的索引
  • edge_weights:可选参数,形状为[num_edges]的一维torch.Tensor,用于表示每条边的权重,如果提供了该参数,函数会在添加自环时相应地处理边的权重
  • fill_value:同样是可选参数,用于指定添加的自环的边的权重值,默认为1
  • num_nodes:可选参数,是一个整数,代表图中节点的总数,如果不指定该参数,函数会根据edge_index中最大节点的索引加1来确定节点数

返回值:返回一个元组(edge_index_new, edge_weight_new)

  • edge_index_new:是一个形状为[2, num_edges_new]的二维torch.Tensor,表示添加自环后的边索引,其中num_edges_new是添加自环后边的总数
  • edge_weight_new:是一个形状为[num_edges_new]的一维torch.Tensor,表示添加自环后边的权重。如果edge_weightNone,则返回一个None

PyG进阶
https://blog.shinebook.net/2025/05/07/人工智能/pytorch/PyG进阶/
作者
X
发布于
2025年5月7日
许可协议