常用方法
torch.nn.Flatten
nn.Flatten
用于将输入数据展平。在深度学习中,尤其是在处理图像数据时,我们经常需要将多维的图像数据展平为一维向量,以便将其输入到全连接层(也称为线性层)中进行进一步的处理。
1 |
|
start_dim
:开始展平的维度。默认为1,意味着从第二个维度开始展平(在PyTorch中,第一个维度通常是批次大小)end_dim
:结束展平的维度。默认为-1,意味着展平到最后一个维度
比如:
1 |
|
输出结果:
1 |
|
torch.nn.BatchNorm2d
1 |
|
num_features
:特征数量,即通道数(C)。例如,对于RGB图像,这个值应该是3。eps
:一个很小的数值\(\epsilon\),用于数值稳定,默认为1e-5。momentum
:动量值,用于计算运行平均值和方差,默认为0.1。affine
:一个布尔值,当设置为True
时,批量归一化层会有可学习的参数(权重和偏置),默认为True
。track_running_stats
:一个布尔值,当设置为True
时,会跟踪运行平均值和方差,默认为True
。
在训练过程中,nn.BatchNorm2d
会维护一个运行平均值和方差,用于在评估模式(如推理时)对数据进行归一化。这些值是通过动量(momentum)更新的。
注意:
nn.BatchNorm2d
应在卷积层之后立即使用,以标准化其输出。- 在训练模式(
model.train()
)和评估模式(model.eval()
)下,nn.BatchNorm2d
的行为是不同的。在训练模式下,它使用当前批次的统计数据;在评估模式下,它使用运行统计数据。 - 批量归一化对于小批量大小可能效果不佳,因为批次统计数据可能不够稳定。
比如:
1 |
|
torch.nn.BatchNorm1d
1 |
|
num_features
:特征数量,即通道数(C)。例如,对于RGB图像,这个值应该是3。eps
:一个很小的数值\(\epsilon\),用于数值稳定,默认为1e-5。momentum
:动量值,用于计算运行平均值和方差,默认为0.1。affine
:一个布尔值,当设置为True
时,批量归一化层会有可学习的参数(权重和偏置),默认为True
。track_running_stats
:一个布尔值,当设置为True
时,会跟踪运行平均值和方差,默认为True
。
注意:
nn.BatchNorm1d
应在完全连接层(线性层)或其他一维特征处理层之后立即使用,以标准化其输出。- 在训练模式(
model.train()
)和评估模式(model.eval()
)下,nn.BatchNorm1d
的行为是不同的。在训练模式下,它使用当前批次的统计数据;在评估模式下,它使用运行统计数据。 - 批量归一化对于小批量大小可能效果不佳,因为批次统计数据可能不够稳定。
nn.BatchNorm1d
可以用于处理时间序列数据,但在这种情况下,通常需要确保批量大小足够大,以便获得稳定的统计数据。
比如:
1 |
|
torch.permute()
1 |
|
input
:要重新排列维度的输入张量dims
:一个包含整数的新维度顺序的元组或列表。dims
中的每个元素都应该是唯一的,并且范围在0
到input.dim() - 1
之间
torch.permute()
通过重新排列指定张量的维度来工作。我们提供一个新维度顺序,permute
会根据这个顺序返回一个新的张量,其数据与原始张量相同,但维度顺序不同。
比如:
1 |
|
输出结果:
1 |
|
上面的代码将x
张量的第二个维度和第三个维度进行了互换,即将每一个矩阵的高和宽进行了互换,相当于原本的x[c][h][w]
变成了x[c][w][h]
,如果想要具体的元素对应,新张量的最后一维:x[c, w, :]
与原张量的最后一维x[c, :, h]
元素相同
在图像处理中,经常需要将通道维度移动到第一个维度,即将[H, W, C]
变为[C, H, W]
维度,以适应一些模型的输入要求,比如:
1 |
|
输出结果:
1 |
|
contiguous()
torch.Tensor.contiguous()
方法用于确保张量在内存中的存储是连续的。
可能导致张量存储不连续的操作:
- 切片操作:比如,
x = torch.tensor([1, 2, 3, 4, 5])
,y = x[1::2]
(取奇数位置的元素) - 转置操作:比如,
x = torch.tensor([[1, 2], [3, 4]])
,y = x.t()
)(转置) - 变形操作:比如,
x = torch.tensor([1, 2, 3, 4])
,y = x.view(2, 2)
(改变形状)
这些操作可能会改变张量在内存中的存储顺序,导致存储不连续
判断张量是否连续可以使用is_contiguous
方法:
1 |
|
输出结果:
1 |
|
当调用contiguous
方法时,PyTorch会检查张量的存储是否连续,如果不连续,则会创建一个新的连续的张量,并将原张量的数据复制到新张量中。
由于contiguous
涉及到数据复制,因此在处理大型张量时可能会带来性能开销,在实际应用中,应尽量避免不必要的contiguous
调用,以减小性能开销。
view()
操作需要保证张量的连续性。如果张量不连续,view()
操作会报错,比如:
1 |
|
输出结果:
1 |
|
view()
view()
方法用于改变张量的形状而不改变其数据。这个方法返回一个新的张量,该张量与原张量共享内存。view()
方法可用将多维张量进行展平。
view()
方法与reshape()
方法使用方式相同。但view()
方法与reshape()
方法有一定的区别:
view()
:要求张量是连续的。如果张量不连续,需要先调用.contiguous()
。reshape()
:可以处理不连续的张量,并且会返回一个连续的视图。reshape()
更灵活,因为它可以自动处理不连续的张量。如果原始张量在内存中不连续,reshape()
会在内部调用contiguous()
方法,自动处理内存连续型问题。
如果确定张量是连续的,并且追求更高的性能,可以使用
view()
方法。如果不确定张量是否连续,或者需要处理经过复杂操作后的张量,建议使用
reshape()
方法。
1 |
|
torch.nn.ModuleList()
nn.ModuleList()
是PyTorch中用于存储和管理多个子模块(即nn.Module
的实例)。它类似于Python的内置列表,但专门为PyTorch的神经网络模块设计,具有一些特殊的属性和行为。
主要特点:
- 自动注册子模块:
- 当我们将子模块添加到
nn.ModuleList()
中时,这些子模块会被自动注册为父模块的子模块。这意味着它们的状态(如参数和缓冲区)会被包含在父模块的状态中,并且可以通过父模块的state_dict()
方法进行保存和加载。
- 当我们将子模块添加到
- 支持索引和迭代:
- 我们可以使用索引来访问
nn.ModuleList()
中特定的子模块,也可以使用循环来迭代所有子模块
- 我们可以使用索引来访问
- 动态修改:
- 与
nn.Sequential()
不同,nn.ModuleList()
允许在运行时动态地添加或者删除子模块。
- 与
- 不参与前向传播:
nn.ModuleList()
本身不定义前向传播逻辑。它只是一个存储和管理子模块的容器。我们需要手动实现如何使用这些子模块进行前向传播。
注意事项:
- 不要直接修改列表:直接通过索引赋值(比如:
module_list[0] = new_module
)可能会导致问题,因为其不会正确地注册新模块。应该使用append()
或extend()
方法来添加模块,或者使用del
来删除模块
基本用法:
1 |
|
在定义nn.ModuleList()
时,可以传入子模块列表:
1 |
|
Tensor.repeat()
用于张量维度扩展和数据复制,其核心功能是根据指定的重复次数沿各维度复制张量元素
比如,原始张量a
为:
1 |
|
结果:
1 |
|
a
的维度为[2, 2, 2]
让a
在第0维度(最外面的维度)复制一次:
1 |
|
结果:
1 |
|
维度变为[4, 2, 2]
让a
在第1维度复制两次:
1 |
|
结果:
1 |
|
维度变为[2, 6, 2]
让a
在第2维度复制一次:
1 |
|
结果:
1 |
|
维度变为[2, 2, 4]
让a
扩展一个维度并在新的第0维度复制一次:
1 |
|
结果:
1 |
|
维度变为[2, 2, 2, 2]
让a
同时在第1维度和第2维度复制一次:
1 |
|
结果:
1 |
|
相当于执行一次a.repeat(1, 1, 2)
再执行一次a.repeat(1, 2, 1)
维度变为[2, 4, 4]
注意:如果a
本身有n个维度,那么a.repeat()
中至少要有n个维度的数据,否则会报错
torch.tril()
torch.tril()
是PyTorch中用于生成下三角矩阵的函数,它保留矩阵主对角线即以下的元素,其余位置变为0。适用于任何维度的张量,但处理针对最后两个维度
参数说明:
input
:(Tensor)输入张量,至少二维diagonal
:(int,可选)对角线偏移量,默认为0diagonal=0
:保留主对角线及以下diagonal>0
:向上偏移,保留更多上方元素diagonal<0
:向下偏移,保留更少元素
out
:(Tensor,可选)输出张量
比如:
1 |
|
输出结果:
1 |
|
又比如:
1 |
|
输出结果:
1 |
|