参数管理

状态字典

state_dict()方法用于返回模型所有参数和缓冲区的字典

state_dict 是一个字典,包含了模型的所有可学习参数(如权重和偏置)以及持久性缓冲区(如批量归一化的运行均值和方差)。这个字典的键通常是参数的名称。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.nn

class MLP(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(1, 2)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2, 1)
self.w1 = torch.tensor([1.0], requires_grad=True)
self.w2 = nn.Parameter(torch.tensor([2.0]))

def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

mlp = MLP()
print(mlp.state_dict())

输出结果:

1
2
OrderedDict([('w2', tensor([2.])), ('linear1.weight', tensor([[-0.7384],
[-0.5803]])), ('linear1.bias', tensor([ 0.9125, -0.5212])), ('linear2.weight', tensor([[-0.0357, -0.5780]])), ('linear2.bias', tensor([0.0804]))])

由此可知,直接使用torch.tensor()不会被状态字典注册参数,使用nn.Parameter()方法可以被状态字典注册参数,从而可以使用paramenters()或者named_parameters()按顺序访问到参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn

class MLP(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(1, 2)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2, 1)
self.w1 = torch.tensor([1.0], requires_grad=True)
self.w2 = nn.Parameter(torch.tensor([2.0]))

def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

mlp = MLP()
for name, param in mlp.named_parameters():
print(name, param)

输出结果:

1
2
3
4
5
6
7
8
9
10
11
w2 Parameter containing:
tensor([2.], requires_grad=True)
linear1.weight Parameter containing:
tensor([[ 0.3866],
[-0.6693]], requires_grad=True)
linear1.bias Parameter containing:
tensor([-0.4819, 0.4518], requires_grad=True)
linear2.weight Parameter containing:
tensor([[-0.3800, 0.4670]], requires_grad=True)
linear2.bias Parameter containing:
tensor([0.1688], requires_grad=True)

比如:

1
2
3
4
5
6
7
8
9
import torch.nn as nn

model = nn.Sequential(
nn.Linear(1, 2),
nn.ReLU(),
nn.Linear(2, 1)
)

print(model.state_dict())

输出结果:

1
2
3
4
5
6
OrderedDict([('0.weight',
tensor([[-0.4336],
[ 0.8008]])),
('0.bias', tensor([-0.1035, -0.3139])),
('2.weight', tensor([[-0.0643, -0.6914]])),
('2.bias', tensor([-0.4783]))])

通过状态字典访问参数:

1
2
3
4
5
6
7
8
9
import torch.nn as nn

model = nn.Sequential(
nn.Linear(1, 2),
nn.ReLU(),
nn.Linear(2, 1)
)

print(model.state_dict()['0.bias'])

输出结果:

1
tensor([ 0.2850, -0.7601])

参数访问

直接访问

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn

class MLP(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(1, 2)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2, 1)

def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x

mlp = MLP()
print('mlp.linear1.weight:\n', mlp.linear1.weight, end='\n\n')
print('mlp.linear1.bias:\n', mlp.linear1.bias)

输出结果:

1
2
3
4
5
6
7
8
mlp.linear1.weight:
Parameter containing:
tensor([[-0.6804],
[ 0.5881]], requires_grad=True)

mlp.linear1.bias:
Parameter containing:
tensor([-0.6178, -0.5424], requires_grad=True)

如果是使用nn.Sequential()注册的模块,可以使用索引的方式访问某个子模块,再直接访问:

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn

model = nn.Sequential(
nn.Linear(1, 2),
nn.ReLU(),
nn.Linear(2, 1)
)

print(model[0])
print(model[0].weight)

输出结果:

1
2
3
4
Linear(in_features=1, out_features=2, bias=True)
Parameter containing:
tensor([[ 0.0674],
[-0.5861]], requires_grad=True)

按顺序遍历

使用parameters()或者named_parameters()方法,前者按顺序访问被状态字典注册的参数的值,后者访问名字和值

比如:

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn

model = nn.Sequential(
nn.Linear(1, 2),
nn.ReLU(),
nn.Linear(2, 1)
)

for param in model.parameters():
print(param)

输出结果:

1
2
3
4
5
6
7
8
9
Parameter containing:
tensor([[0.8264],
[0.7365]], requires_grad=True)
Parameter containing:
tensor([ 0.2989, -0.9195], requires_grad=True)
Parameter containing:
tensor([[-0.4156, -0.5078]], requires_grad=True)
Parameter containing:
tensor([0.3763], requires_grad=True)

比如:

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn

model = nn.Sequential(
nn.Linear(1, 2),
nn.ReLU(),
nn.Linear(2, 1)
)

for param in model.named_parameters():
print(param)

输出结果:

1
2
3
4
5
6
7
8
9
('0.weight', Parameter containing:
tensor([[-0.5772],
[-0.4926]], requires_grad=True))
('0.bias', Parameter containing:
tensor([ 0.4068, -0.3668], requires_grad=True))
('2.weight', Parameter containing:
tensor([[-0.2434, 0.4881]], requires_grad=True))
('2.bias', Parameter containing:
tensor([-0.5687], requires_grad=True))

参数注册

使用nn.Parameter()或者nn.Module.register_parameter()在模块的状态字典中注册参数,前者注册会将注册的变量名作为参数名,注册的张量值作为参数值;后者接受两个参数,第一个为参数名,第二个为nn.Parameter()参数或者None

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn

class Linear(nn.Module):
def __init__(self, num_input, num_output):
super().__init__()
self.w = nn.Parameter(torch.randn(num_input, num_output))
self.b = nn.Parameter(torch.zeros(num_output))

def forward(self, x):
return torch.matmul(x, self.w) + self.b

lin = Linear(1, 2)
print(lin.state_dict())

输出结果:

1
OrderedDict([('w', tensor([[ 0.1659, -0.6305]])), ('b', tensor([0., 0.]))])

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

class Linear(nn.Module):
def __init__(self, num_input, num_output):
super().__init__()
w = torch.randn(num_input, num_output)
b = torch.zeros(num_output)
self.register_parameter('weight', nn.Parameter(w))
self.register_parameter('bias', nn.Parameter(b))

def forward(self, x):
return torch.matmul(x, self.weight) + self.bias

lin = Linear(2, 1)
print(lin.state_dict())

输出结果:

1
2
OrderedDict([('weight', tensor([[-0.7838],
[-2.3011]])), ('bias', tensor([0.]))])

参数绑定

有时候我们希望在两处共享参数或模块,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn
import torch.nn.functional as F

shared = nn.Linear(8, 8)
net = nn.Sequential(
nn.Linear(4, 8),
nn.ReLU(),
shared,
nn.ReLU(),
shared,
nn.ReLU(),
nn.Linear(8, 1)
)
print(net[2].weight.data[0] == net[4].weight.data[0])
net[2].weight.data[0, 0] = 100
print(net[2].weight.data[0] == net[4].weight.data[0])

输出结果:

1
2
tensor([True, True, True, True, True, True, True, True])
tensor([True, True, True, True, True, True, True, True])

这样就实现了参数共享或者模块共享

参数保存和加载

保存

使用torch.save()的方式保存模型或参数,使用torch.load()加载模型或参数

1
torch.save(obj, f, pickle_protocol=2, _use_new_zipfile_serialization=None)
  • obj:要保存的 Python 对象。这个对象可以是任何可以被 Python 的 pickle 模块序列化的对象,包括但不限于 PyTorch 的模型、张量(torch.Tensor)、字典、列表等。
  • f:文件名或文件句柄,指定保存对象的位置。可以是一个字符串(文件路径)或一个具有 write() 方法的文件句柄。
  • pickle_protocol(可选):指定用于序列化的pickle协议版本。默认是 2,但我们可以选择更高的版本以获得更好的性能或兼容性。例如,pickle_protocol=4 可以保存大于 4GB 的对象。
  • **_use_new_zipfile_serialization**(可选):这是一个内部使用的参数,用于控制是否使用新的 ZIP 文件序列化格式。默认值是 None,通常不需要手动设置。

Pickle模块定义了多种协议版本,每个版本都有其特定的功能和限制:

  1. 协议0:原始的“人类可读”协议,兼容性最好,但效率较低。
  2. 协议1:旧的二进制格式,比协议0更高效。
  3. 协议2:引入了新的操作符,支持更高效的序列化。
  4. 协议3:引入了字节对象的支持,可以序列化字节字符串。
  5. 协议4:引入了大数据支持,可以序列化大于4GB的对象。
  6. 协议5:引入了带外数据支持,用于优化序列化性能。

torch.save() 函数中,pickle_protocol 参数允许我们指定使用哪个pickle协议版本来进行序列化。默认情况下,PyTorch使用协议2,因为它在大多数情况下提供了良好的性能和兼容性。

保存整个模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设我们有一个简单的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(20 * 12 * 12, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 20 * 12 * 12)
x = F.relu(self.fc1(x))
return x

model = MyModel()
# 假设模型已经训练完成

# 保存整个模型
torch.save(model, 'model whole.pth')

保存模型的状态字典

通常,我们只保存模型的状态字典(state_dict),它包含了模型的所有参数和缓冲区。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设我们有一个简单的模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(20 * 12 * 12, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 20 * 12 * 12)
x = F.relu(self.fc1(x))
return x

model = MyModel()
# 假设模型已经训练完成

# 保存模型的状态字典
torch.save(model.state_dict(), 'model state_dict.pth')

加载

加载整个模型

1
torch.load(fname, map_location=None, pickle_module=None, pickle_load_args=None, torch_device=None, weights_only=False, strict=True)
  1. fname (str or file-like object) - 必选参数
    • 指定要加载的文件名或文件对象。可以是字符串路径或任何支持 read() 方法的文件对象。
  2. map_location (str, function, dict, or torch.device) - 可选参数
    • 指定如何重新映射存储位置。例如,可以将模型从GPU加载到CPU,或者从一台机器的GPU加载到另一台机器的GPU。
    • 可以是字符串(如 'cpu', 'cuda:0'),函数,字典或 torch.device 对象。
    • 默认值为 None,表示使用保存时的设备。
  3. pickle_module (module) - 可选参数
    • 指定用于反序列化的pickle模块。默认使用 pickle 模块。
    • 允许使用自定义的pickle模块,例如用于兼容性或安全性目的。
  4. pickle_load_args (dict) - 可选参数
    • 指定传递给pickle模块的额外参数。这些参数将直接传递给 pickle.load() 函数。
    • 默认值为 None
  5. torch_device (str or torch.device) - 可选参数
    • 指定加载张量时使用的设备。这个参数是 map_location 的简写形式,用于简单的情况。
    • 例如,'cpu', 'cuda:0'torch.device('cuda:0')
  6. weights_only (bool) - 可选参数
    • 指定是否仅加载模型的权重。如果设置为 True,则不会加载优化器状态和其他非权重信息。
    • 默认值为 False
  7. strict (bool) - 可选参数
    • 当加载模型权重时,指定是否严格匹配键名。如果设置为 True,则所有键名必须匹配,否则会抛出错误。
    • 默认值为 True

比如:

1
2
3
4
5
6
import torch
import torch.nn as nn

# 加载整个模型
model_loaded = torch.load('model whole.pth')
model_loaded.eval() # 设置为评估模式

加载模型的状态字典

如果我们只保存了状态字典,我们需要先创建一个与保存时结构相同的模型实例,然后加载状态字典

1
model.load_state_dict(state_dict, strict=True)
  1. state_dict (dict) - 必选参数
    • 一个包含模型参数和缓冲区的字典。这个字典通常是通过 state_dict() 方法获取的,也可以是通过 torch.save() 保存并加载的。
  2. strict (bool) - 可选参数
    • 指定是否严格匹配键名。如果设置为 True,则 state_dict 中的键必须与模块的 state_dict() 返回的键完全匹配,否则会抛出错误。如果设置为 False,则允许键名不匹配,未匹配的键将被忽略。
    • 默认值为 True

比如:

1
2
3
4
5
6
import torch
import torch.nn as nn

new_model = ......
# 如果在GPU上保存,但在CPU上加载
new_model.load_state_dict(torch.load('model state_dict.pth', map_location=torch.device('cpu')))

参数管理
https://blog.shinebook.net/2025/03/04/人工智能/pytorch/参数管理/
作者
X
发布于
2025年3月4日
许可协议