神经网络的评估模式和训练模式

评估模式和训练模式是pytorch神经网络两种不同的模型状态。

评估模式和训练模式

训练模式:

训练模式是模型在接收训练数据并进行学习时的状态。在这种模式下,模型会根据输入数据进行前向传播,计算损失,并通过反向传播更新权重。

特点:

  1. 权重更新: 模型的权重会根据损失函数的梯度进行更新。
  2. Dropout激活: 如果模型中包含了Dropout层,这些层会在前向传播时随机将一部分神经元的输出置零,以减少过拟合。
  3. Batch Normalization层行为: 在训练模式下,Batch Normalization层会使用当前批次的数据来计算均值和方差,用于标准化。

使用train()方法将神经网络模型设置为训练模式。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn

model = nn.Sequential(
nn.Linear(10, 100),
nn.ReLU(),
nn.Linear(100, 10),
nn.ReLU(),
nn.Linear(10, 1)
)

dataloader = ......
optimizer = ......
loss_function = ......

# 训练模式
model.train()
for data, target in dataloader:
optimizer.zero_grad()
output = model(data)
loss = loss_function(output, target)
loss.backward()
optimizer.step()

评估模式:

评估模式是模型在处理非训练数据(如验证集或测试集)时的状态。在这种模式下,模型不会更新权重,而是用于生成预测或评估性能

特点:

  1. 权重固定: 模型的权重不会更新,即不会进行反向传播和梯度下降。如果设置了model.eval()后,再使用optimizer.step(),会报错。
  2. Dropout关闭: Dropout层会保留所有神经元的输出,但会乘以保留概率(1-p),以保持输出规模与训练时相同。
  3. Batch Normalization层行为: 在评估模式下,Batch Normalization层会使用训练过程中累积的均值和方差来进行标准化,而不是当前批次的统计数据

使用eval()方法将神经网络模型设置为评估模式,配合with torch.no_grad()不计算梯度

比如:

1
2
3
4
5
model.eval()	# 设置模型为评估模式
with torch.no_grad(): # 不计算梯度
for data, target in test_loader:
output = model(data)
# ... 计算评估指标,如准确率 ...

torch.no_grad()

torch.no_grad()是Pytorch中用于控制梯度计算的一个上下文管理器,在torch.no_grad()的上下文中执行操作中,它会告诉Autograd暂时不需要为这些操作计算梯度。

使用场景:

  1. 模型评估: 在评估模型性能时,我们通常不需要计算梯度,因此可以使用torch.no_grad()来避免不必要的计算。
  2. 推理: 在部署模型进行实际推理时,同样不需要计算梯度,使用torch.no_grad()可以加快推理速度并减少内存消耗。
  3. 特定操作: 有时,在训练过程中,某些特定操作(如指标计算)不需要梯度信息,可以在这些操作的代码块中使用torch.no_grad()

torch.no_grad()通常与with一起使用:

1
2
3
4
5
6
7
8
9
10
11
import torch

x = torch.tensor([1.0, 2.0])
w = torch.tensor([2.2, -1.2], requires_grad=True)
b = torch.zeros(1, requires_grad=True)
with torch.no_grad():
y = torch.matmul(x, w) + b
try:
y.backward()
except Exception as e:
print(e)

错误:

1
element 0 of tensors does not require grad and does not have a grad_fn

神经网络的评估模式和训练模式
https://blog.shinebook.net/2025/03/07/人工智能/pytorch/神经网络的评估模式和训练模式/
作者
X
发布于
2025年3月7日
许可协议