PyG
介绍
PyG(PyTorch Geometric)是一个基于PyTorch开发的图神经网络(GNN)专用库,旨在简化图结构数据的深度学习任务,能够与PyTorch无缝兼容。
PyG提供很多种GNN层,包括:GCNConv
(图卷积)、GATConv
(图注意力)、SAGEConv
(GraphSAGE)、TransformerConv
(图Transformer)等。
PyG内置Cora
、Pubmed
等经典数据集,支持一键下载与预处理。
安装与使用
1 |
|
图的数据表示
创建
PyG要求数据以Data
对象传递,Data
是PyG中表示单张图数据的容器类,继承自torch.Tensor
,专门用于存储图结构数据。
Data
的核心属性:
必需属性:
属性名 数据类型 维度要求 说明 x
torch.Tensor
[num_nodes, num_node_features]
节点特征矩阵(若没有节点特征,可不设置) edge_index
torch.LongTensor
[2, num_edges]
边的连接关系(COO 格式,第0行为源节点,第1行为目标节点) 可选属性:
属性名 数据类型 说明 edge_attr
torch.Tensor
边特征矩阵,维度为 [num_edges, num_edge_features]
y
torch.Tensor
标签(节点级标签: [num_nodes]
;图级标签:[1]
)pos
torch.Tensor
节点坐标(如 3D 点云数据),维度为 [num_nodes, 3]
batch
torch.LongTensor
批处理索引(用于区分多张图合并后的节点归属) 自定义属性名 可以添加任意自定义的属性
比如:
1 |
|
验证
数据验证:使用.validate()
方法检查数据格式是否正确。
1 |
|
实用方法
设备迁移
1
2
3
4# 将数据迁移到GPU
data = data.to('cuda:0')
# 检查设备
print(data.x.device)属性操作
1
2
3
4# 添加自定义属性
data.custom_attr = "This is a test graph"
# 查看所有属性
print(data) # 输出: Data(x=[4,3], edge_index=[2,3], y=[4], custom_attr='...')图结构分析
方法 说明 data.num_nodes
返回节点数(若未显式设置 x
,需通过edge_index
推断)data.num_edges
返回边数(等于 edge_index.shape[1]
)data.is_directed()
判断是否为有向图(若存在反向边则为无向图)
异构图的创建
异构图需要使用HeteroData
:
1 |
|
定义节点类型与特征:
1 |
|
添加边索引:使用三元组(源节点类型,边类型,目标节点类型)
定义边,并分配COO格式的边索引(形状为[2, num_edges]
)
1 |
|
添加边特征(可选):可以为边添加特征
1 |
|
验证异构图结构:打印HeteroData
对象查看结构信息
1 |
|
输出示例:
1 |
|
检查节点和边的数量:
1 |
|
转换为GPU张量(可选):
1 |
|
注意:如果某类节点无特征,可以不设置.x
属性,但某些模型可能需要初始化(例如使用全零张量)
模型构建
GCN
PyG中的GCNConv
是GCN的核心层,
1 |
|
CGNConv()
内部不包含激活函数,需手动添加。
定义一个两层的GCN:
1 |
|
RGCN
PyG中的RGCNConv
是RGCN的核心层:
1 |
|
num_bases
:基分解,通过共享基矩阵\(\mathbf{V}_b\)组合生成各种关系的\(\mathbf{W}_r\): \[ \mathbf{W}_r = \sum_{b=1}^B a_{rb}\mathbf{V}_b \] 适用场景:关系间存在潜在共性(如社交网络中不同互动类型有相似模式)num_blocks
:块对角分解,将权重矩阵拆分为块对角结构,减少稠密参数 \[ \mathbf{W}_r = \mathrm{diag}(\mathbf{Q}_{1r},\cdots,\mathbf{Q}_{Br}) \] 适用场景:关系类型差异较大,需保持参数独立性
RGCNConv()
对象输入数据格式:
- 节点特征矩阵
x
,维度为[num_nodes, in_channels]
(如果是异构图,有不同类型的节点,需要将它们合并) - 边索引
edge_index
,维度为[2, num_edges]
(如果是异构图,有不同类型的边,需要将它们合并) - 边类型
edge_type
,维度为[num_edges]
(值范围0
到num_relations-1
,第\(i\)个值代表上面边索引edge_index
中第\(i\)条边的类型)
由此可以看出,RGCNConv()
仅区分边类型,不区分节点类型。
RGCNConv()
内部并不包含激活函数,需手动添加。
示例:
假设一个社交网络,包含3个用户,两种关秀(关注(类型0)、好友(类型1))
1 |
|