torch.nn.init
torch.nn.init
是Pytorch提供的一个模块,用于初始化神经网络模型的权重和偏置。
常数初始化
1 nn.init.constant_(tensor, val)
正态分布初始化
1 nn.init.normal_(tensor, mean=0.0 , std=1.0 )
使用均值为mean
,标准差为std
的正态分布来初始化tensor
均匀分布初始化
1 nn.init.uniform_(tensor, a=0.0 , b=1.0 )
使用区间[a, b)
上的均匀分布来初始化tensor
Xavier初始化(Glorot初始化)
1 2 nn.init.xavier_uniform_(tensor, gain=1.0 ) nn.init.xavier_normal_(tensor, gain=1.0 )
Xavier初始化旨在保持输入和输出的方差一致,适用于Sigmoid
和Tanh
激活函数
xavier_uniform_
使用均匀分布,xavier_normal_
使用正态分布
gain
是一个可缩放因子,用于调整分布的标准差:
在Xavier初始化中,权重的标准差计算公式为:std_dev = gain * sqrt(2.0 / (fan_in, fan_out))
,其中,fan_in
是权重的输入单元数,fan_out
是权重的输出单元数。
对于Tanh
激活函数,gain
通常为1;对于Sigmoid
激活函数,gain
通常为4;对于ReLU
激活函数,gain
通常为\(\sqrt{2}\)
He初始化
1 2 nn.init.kaiming_uniform_(tensor, a=0 , mode='fan_in' , nonlinearity='leaky_relu' ) nn.init.kaiming_normal_(tensor, a=0 , mode='fan_in' , nonlinearity='leaky_relu' )
He初始化适用于ReLU和其变体(如Leaky ReLU)激活函数
kaiming_uniform_
使用均匀分布,kaiming_normal_
使用正态分布
a
是激活函数的负斜率(仅用于Leaky ReLU)
mode
可以是fan_in
(默认)或fan_out
,表示使用输入单位数或输出单位数来计算标准差
nonlinearity
是使用激活函数的类型,可以为relu
、leaky_relu
、linear
等
Orthogonal初始化
1 nn.init.orthogonal_(tensor, gain=1 )
Orthogonal 初始化生成一个正交矩阵,适用于 RNN 和 LSTM
等循环网络
gain
是一个可选的缩放因子
单位矩阵初始化
Dirac初始化
1 nn.init.dirac_(tensor, groups=1 )
Dirac 初始化用于卷积层,将权重初始化为 Dirac delta
函数,即一个中心为 1、其余为 0 的矩阵。
全零/全一初始化
1 2 nn.init.zeros_(tensor) nn.init.ones_(tensor)
应用初始化
可以使用apply()
方法应用或者直接调用参数。下面是直接调用的方法:
1 2 3 4 5 for name, param in model.named_parameters(): if 'weight' in name: nn.init.xavier_normal_(param) elif 'bias' in name: nn.init.zeros_(param)
apply()
方法实现初始化
神经网络实例.apply()
是Pytorch中的一种常用方法,用于将一个函数应用到模型的所有子模块上
apply()
方法接收一个函数作为参数,并将这个函数应用到所有子模块(即该模型中所有nn.Module
的实例)上。每个子模块都会被传递给这个函数,以便进行相应的操作。
比如:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 import torchimport torch.nn as nnimport torch.nn.functional as Fclass CustomNet (nn.Module): def __init__ (self ): super ().__init__() self .conv1 = nn.Conv2d(1 , 32 , 3 , 1 ) self .conv2 = nn.Conv2d(32 , 64 , 3 , 1 ) self .fc1 = nn.Linear(9216 , 128 ) self .fc2 = nn.Linear(128 , 10 ) def forward (self, x ): x = self .conv1(x) x = F.relu(x) x = self .conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2 ) x = torch.flatten(x, 1 ) x = self .fc1(x) x = F.relu(x) x = self .fc2(x) return F.log_softmax(x, dim=1 ) def init_weights (m ): if isinstance (m, nn.Conv2d) or isinstance (m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None : nn.init.zeros_(m.bias) model = CustomNet() model.apply(init_weights)