状态字典
state_dict()
方法用于返回模型所有参数和缓冲区的字典
state_dict
是一个字典,包含了模型的所有可学习参数(如权重和偏置)以及持久性缓冲区(如批量归一化的运行均值和方差)。这个字典的键通常是参数的名称。
比如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| import torch import torch.nn
class MLP(nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(1, 2) self.relu = nn.ReLU() self.linear2 = nn.Linear(2, 1) self.w1 = torch.tensor([1.0], requires_grad=True) self.w2 = nn.Parameter(torch.tensor([2.0])) def forward(self, x): x = self.linear1(x) x = self.relu(x) x = self.linear2(x) return x
mlp = MLP() print(mlp.state_dict())
|
输出结果:
1 2
| OrderedDict([('w2', tensor([2.])), ('linear1.weight', tensor([[-0.7384], [-0.5803]])), ('linear1.bias', tensor([ 0.9125, -0.5212])), ('linear2.weight', tensor([[-0.0357, -0.5780]])), ('linear2.bias', tensor([0.0804]))])
|
由此可知,直接使用torch.tensor()
不会被状态字典注册参数,使用nn.Parameter()
方法可以被状态字典注册参数,从而可以使用paramenters()
或者named_parameters()
按顺序访问到参数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| import torch import torch.nn
class MLP(nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(1, 2) self.relu = nn.ReLU() self.linear2 = nn.Linear(2, 1) self.w1 = torch.tensor([1.0], requires_grad=True) self.w2 = nn.Parameter(torch.tensor([2.0])) def forward(self, x): x = self.linear1(x) x = self.relu(x) x = self.linear2(x) return x
mlp = MLP() for name, param in mlp.named_parameters(): print(name, param)
|
输出结果:
1 2 3 4 5 6 7 8 9 10 11
| w2 Parameter containing: tensor([2.], requires_grad=True) linear1.weight Parameter containing: tensor([[ 0.3866], [-0.6693]], requires_grad=True) linear1.bias Parameter containing: tensor([-0.4819, 0.4518], requires_grad=True) linear2.weight Parameter containing: tensor([[-0.3800, 0.4670]], requires_grad=True) linear2.bias Parameter containing: tensor([0.1688], requires_grad=True)
|
比如:
1 2 3 4 5 6 7 8 9
| import torch.nn as nn
model = nn.Sequential( nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1) )
print(model.state_dict())
|
输出结果:
1 2 3 4 5 6
| OrderedDict([('0.weight', tensor([[-0.4336], [ 0.8008]])), ('0.bias', tensor([-0.1035, -0.3139])), ('2.weight', tensor([[-0.0643, -0.6914]])), ('2.bias', tensor([-0.4783]))])
|
通过状态字典访问参数:
1 2 3 4 5 6 7 8 9
| import torch.nn as nn
model = nn.Sequential( nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1) )
print(model.state_dict()['0.bias'])
|
输出结果:
1
| tensor([ 0.2850, -0.7601])
|
参数访问
直接访问
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| import torch import torch.nn
class MLP(nn.Module): def __init__(self): super().__init__() self.linear1 = nn.Linear(1, 2) self.relu = nn.ReLU() self.linear2 = nn.Linear(2, 1) def forward(self, x): x = self.linear1(x) x = self.relu(x) x = self.linear2(x) return x
mlp = MLP() print('mlp.linear1.weight:\n', mlp.linear1.weight, end='\n\n') print('mlp.linear1.bias:\n', mlp.linear1.bias)
|
输出结果:
1 2 3 4 5 6 7 8
| mlp.linear1.weight: Parameter containing: tensor([[-0.6804], [ 0.5881]], requires_grad=True)
mlp.linear1.bias: Parameter containing: tensor([-0.6178, -0.5424], requires_grad=True)
|
如果是使用nn.Sequential()
注册的模块,可以使用索引的方式访问某个子模块,再直接访问:
1 2 3 4 5 6 7 8 9 10 11
| import torch import torch.nn as nn
model = nn.Sequential( nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1) )
print(model[0]) print(model[0].weight)
|
输出结果:
1 2 3 4
| Linear(in_features=1, out_features=2, bias=True) Parameter containing: tensor([[ 0.0674], [-0.5861]], requires_grad=True)
|
按顺序遍历
使用parameters()
或者named_parameters()
方法,前者按顺序访问被状态字典注册的参数的值,后者访问名字和值
比如:
1 2 3 4 5 6 7 8 9 10 11
| import torch import torch.nn as nn
model = nn.Sequential( nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1) )
for param in model.parameters(): print(param)
|
输出结果:
1 2 3 4 5 6 7 8 9
| Parameter containing: tensor([[0.8264], [0.7365]], requires_grad=True) Parameter containing: tensor([ 0.2989, -0.9195], requires_grad=True) Parameter containing: tensor([[-0.4156, -0.5078]], requires_grad=True) Parameter containing: tensor([0.3763], requires_grad=True)
|
比如:
1 2 3 4 5 6 7 8 9 10 11
| import torch import torch.nn as nn
model = nn.Sequential( nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1) )
for param in model.named_parameters(): print(param)
|
输出结果:
1 2 3 4 5 6 7 8 9
| ('0.weight', Parameter containing: tensor([[-0.5772], [-0.4926]], requires_grad=True)) ('0.bias', Parameter containing: tensor([ 0.4068, -0.3668], requires_grad=True)) ('2.weight', Parameter containing: tensor([[-0.2434, 0.4881]], requires_grad=True)) ('2.bias', Parameter containing: tensor([-0.5687], requires_grad=True))
|
参数注册
使用nn.Parameter()
或者nn.Module.register_parameter()
在模块的状态字典中注册参数,前者注册会将注册的变量名作为参数名,注册的张量值作为参数值;后者接受两个参数,第一个为参数名,第二个为nn.Parameter()
参数或者None
比如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| import torch import torch.nn as nn
class Linear(nn.Module): def __init__(self, num_input, num_output): super().__init__() self.w = nn.Parameter(torch.randn(num_input, num_output)) self.b = nn.Parameter(torch.zeros(num_output)) def forward(self, x): return torch.matmul(x, self.w) + self.b
lin = Linear(1, 2) print(lin.state_dict())
|
输出结果:
1
| OrderedDict([('w', tensor([[ 0.1659, -0.6305]])), ('b', tensor([0., 0.]))])
|
比如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| import torch import torch.nn as nn
class Linear(nn.Module): def __init__(self, num_input, num_output): super().__init__() w = torch.randn(num_input, num_output) b = torch.zeros(num_output) self.register_parameter('weight', nn.Parameter(w)) self.register_parameter('bias', nn.Parameter(b)) def forward(self, x): return torch.matmul(x, self.weight) + self.bias lin = Linear(2, 1) print(lin.state_dict())
|
输出结果:
1 2
| OrderedDict([('weight', tensor([[-0.7838], [-2.3011]])), ('bias', tensor([0.]))])
|
参数绑定
有时候我们希望在两处共享参数或模块,
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| import torch import torch.nn as nn import torch.nn.functional as F
shared = nn.Linear(8, 8) net = nn.Sequential( nn.Linear(4, 8), nn.ReLU(), shared, nn.ReLU(), shared, nn.ReLU(), nn.Linear(8, 1) ) print(net[2].weight.data[0] == net[4].weight.data[0]) net[2].weight.data[0, 0] = 100 print(net[2].weight.data[0] == net[4].weight.data[0])
|
输出结果:
1 2
| tensor([True, True, True, True, True, True, True, True]) tensor([True, True, True, True, True, True, True, True])
|
这样就实现了参数共享或者模块共享
参数保存和加载
保存
使用torch.save()
的方式保存模型或参数,使用torch.load()
加载模型或参数
1
| torch.save(obj, f, pickle_protocol=2, _use_new_zipfile_serialization=None)
|
- obj:要保存的 Python 对象。这个对象可以是任何可以被
Python 的
pickle
模块序列化的对象,包括但不限于 PyTorch
的模型、张量(torch.Tensor
)、字典、列表等。
- f:文件名或文件句柄,指定保存对象的位置。可以是一个字符串(文件路径)或一个具有
write()
方法的文件句柄。
- pickle_protocol(可选):指定用于序列化的pickle协议版本。默认是
2
,但我们可以选择更高的版本以获得更好的性能或兼容性。例如,pickle_protocol=4
可以保存大于 4GB 的对象。
- **_use_new_zipfile_serialization**(可选):这是一个内部使用的参数,用于控制是否使用新的
ZIP 文件序列化格式。默认值是
None
,通常不需要手动设置。
Pickle模块定义了多种协议版本,每个版本都有其特定的功能和限制:
- 协议0:原始的“人类可读”协议,兼容性最好,但效率较低。
- 协议1:旧的二进制格式,比协议0更高效。
- 协议2:引入了新的操作符,支持更高效的序列化。
- 协议3:引入了字节对象的支持,可以序列化字节字符串。
- 协议4:引入了大数据支持,可以序列化大于4GB的对象。
- 协议5:引入了带外数据支持,用于优化序列化性能。
在 torch.save()
函数中,pickle_protocol
参数允许我们指定使用哪个pickle协议版本来进行序列化。默认情况下,PyTorch使用协议2,因为它在大多数情况下提供了良好的性能和兼容性。
保存整个模型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| import torch import torch.nn as nn import torch.nn.functional as F
class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(20 * 12 * 12, 10)
def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = x.view(-1, 20 * 12 * 12) x = F.relu(self.fc1(x)) return x
model = MyModel()
torch.save(model, 'model whole.pth')
|
保存模型的状态字典
通常,我们只保存模型的状态字典(state_dict
),它包含了模型的所有参数和缓冲区。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| import torch import torch.nn as nn import torch.nn.functional as F
class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(20 * 12 * 12, 10)
def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = x.view(-1, 20 * 12 * 12) x = F.relu(self.fc1(x)) return x
model = MyModel()
torch.save(model.state_dict(), 'model state_dict.pth')
|
加载
加载整个模型
1
| torch.load(fname, map_location=None, pickle_module=None, pickle_load_args=None, torch_device=None, weights_only=False, strict=True)
|
- fname (str or file-like object) - 必选参数
- 指定要加载的文件名或文件对象。可以是字符串路径或任何支持
read()
方法的文件对象。
- map_location (str, function, dict, or torch.device)
- 可选参数
- 指定如何重新映射存储位置。例如,可以将模型从GPU加载到CPU,或者从一台机器的GPU加载到另一台机器的GPU。
- 可以是字符串(如
'cpu'
,
'cuda:0'
),函数,字典或 torch.device
对象。
- 默认值为
None
,表示使用保存时的设备。
- pickle_module (module) - 可选参数
- 指定用于反序列化的pickle模块。默认使用
pickle
模块。
- 允许使用自定义的pickle模块,例如用于兼容性或安全性目的。
- pickle_load_args (dict) - 可选参数
- 指定传递给pickle模块的额外参数。这些参数将直接传递给
pickle.load()
函数。
- 默认值为
None
。
- torch_device (str or torch.device) - 可选参数
- 指定加载张量时使用的设备。这个参数是
map_location
的简写形式,用于简单的情况。
- 例如,
'cpu'
, 'cuda:0'
或
torch.device('cuda:0')
。
- weights_only (bool) - 可选参数
- 指定是否仅加载模型的权重。如果设置为
True
,则不会加载优化器状态和其他非权重信息。
- 默认值为
False
。
- strict (bool) - 可选参数
- 当加载模型权重时,指定是否严格匹配键名。如果设置为
True
,则所有键名必须匹配,否则会抛出错误。
- 默认值为
True
。
比如:
1 2 3 4 5 6
| import torch import torch.nn as nn
model_loaded = torch.load('model whole.pth') model_loaded.eval()
|
加载模型的状态字典
如果我们只保存了状态字典,我们需要先创建一个与保存时结构相同的模型实例,然后加载状态字典
1
| model.load_state_dict(state_dict, strict=True)
|
- state_dict (dict) - 必选参数
- 一个包含模型参数和缓冲区的字典。这个字典通常是通过
state_dict()
方法获取的,也可以是通过
torch.save()
保存并加载的。
- strict (bool) - 可选参数
- 指定是否严格匹配键名。如果设置为
True
,则
state_dict
中的键必须与模块的 state_dict()
返回的键完全匹配,否则会抛出错误。如果设置为
False
,则允许键名不匹配,未匹配的键将被忽略。
- 默认值为
True
。
比如:
1 2 3 4 5 6
| import torch import torch.nn as nn
new_model = ......
new_model.load_state_dict(torch.load('model state_dict.pth', map_location=torch.device('cpu')))
|