PyG
介绍
PyG(PyTorch Geometric)是一个基于PyTorch开发的图神经网络(GNN)专用库,旨在简化图结构数据的深度学习任务,能够与PyTorch无缝兼容。
PyG提供很多种GNN层,包括:GCNConv(图卷积)、GATConv(图注意力)、SAGEConv(GraphSAGE)、TransformerConv(图Transformer)等。
PyG内置Cora、Pubmed等经典数据集,支持一键下载与预处理。
安装与使用
1 | |
图的数据表示
创建
PyG要求数据以Data对象传递,Data是PyG中表示单张图数据的容器类,继承自torch.Tensor,专门用于存储图结构数据。
Data的核心属性:
必需属性:
属性名 数据类型 维度要求 说明 xtorch.Tensor[num_nodes, num_node_features]节点特征矩阵(若没有节点特征,可不设置) edge_indextorch.LongTensor[2, num_edges]边的连接关系(COO 格式,第0行为源节点,第1行为目标节点) 可选属性:
属性名 数据类型 说明 edge_attrtorch.Tensor边特征矩阵,维度为 [num_edges, num_edge_features]ytorch.Tensor标签(节点级标签: [num_nodes];图级标签:[1])postorch.Tensor节点坐标(如 3D 点云数据),维度为 [num_nodes, 3]batchtorch.LongTensor批处理索引(用于区分多张图合并后的节点归属) 自定义属性名 可以添加任意自定义的属性
比如:
1 | |
验证
数据验证:使用.validate()方法检查数据格式是否正确。
1 | |
实用方法
设备迁移
1
2
3
4# 将数据迁移到GPU
data = data.to('cuda:0')
# 检查设备
print(data.x.device)属性操作
1
2
3
4# 添加自定义属性
data.custom_attr = "This is a test graph"
# 查看所有属性
print(data) # 输出: Data(x=[4,3], edge_index=[2,3], y=[4], custom_attr='...')图结构分析
方法 说明 data.num_nodes返回节点数(若未显式设置 x,需通过edge_index推断)data.num_edges返回边数(等于 edge_index.shape[1])data.is_directed()判断是否为有向图(若存在反向边则为无向图)
异构图的创建
异构图需要使用HeteroData:
1 | |
定义节点类型与特征:
1 | |
添加边索引:使用三元组(源节点类型,边类型,目标节点类型)定义边,并分配COO格式的边索引(形状为[2, num_edges])
1 | |
添加边特征(可选):可以为边添加特征
1 | |
验证异构图结构:打印HeteroData对象查看结构信息
1 | |
输出示例:
1 | |
检查节点和边的数量:
1 | |
转换为GPU张量(可选):
1 | |
注意:如果某类节点无特征,可以不设置.x属性,但某些模型可能需要初始化(例如使用全零张量)
模型构建
GCN
PyG中的GCNConv是GCN的核心层,
1 | |
CGNConv()内部不包含激活函数,需手动添加。
定义一个两层的GCN:
1 | |
RGCN
PyG中的RGCNConv是RGCN的核心层:
1 | |
num_bases:基分解,通过共享基矩阵\(\mathbf{V}_b\)组合生成各种关系的\(\mathbf{W}_r\): \[ \mathbf{W}_r = \sum_{b=1}^B a_{rb}\mathbf{V}_b \] 适用场景:关系间存在潜在共性(如社交网络中不同互动类型有相似模式)num_blocks:块对角分解,将权重矩阵拆分为块对角结构,减少稠密参数 \[ \mathbf{W}_r = \mathrm{diag}(\mathbf{Q}_{1r},\cdots,\mathbf{Q}_{Br}) \] 适用场景:关系类型差异较大,需保持参数独立性
RGCNConv()对象输入数据格式:
- 节点特征矩阵
x,维度为[num_nodes, in_channels](如果是异构图,有不同类型的节点,需要将它们合并) - 边索引
edge_index,维度为[2, num_edges](如果是异构图,有不同类型的边,需要将它们合并) - 边类型
edge_type,维度为[num_edges](值范围0到num_relations-1,第\(i\)个值代表上面边索引edge_index中第\(i\)条边的类型)
由此可以看出,RGCNConv()仅区分边类型,不区分节点类型。
RGCNConv()内部并不包含激活函数,需手动添加。
示例:
假设一个社交网络,包含3个用户,两种关秀(关注(类型0)、好友(类型1))
1 | |