torchvision

torchvision介绍

torchvision是一个与Pytorch深度学习框架紧密集成的计算机视觉库。它提供了许多用于图像处理、数据加载、模型预训练以及常见视觉任务(如图像分类、目标检测等)的工具和函数

torchvision模块

torchvision.datasets

用于加载特定数据集

这个模块提供了一组常用的视觉数据集,例如 CIFAR、MNIST、ImageNet 等。每个数据集都是一个 Dataset 对象,可以很容易地与 PyTorch 的数据加载器(DataLoader)一起使用。

torchvison.datasets提供了多种数据集,以下是一些常用的数据集:

  • CIFARCIFAR-10CIFAR-100数据集,用于图像分类任务
  • MNIST:手写数据集,用于图像分类任务
  • ImageNet:大规模视觉数据库,用于图像分类任务
  • COCO:用于目标检测、分割和标注的数据集
  • VOC:PASCAL Visual Object Classes 数据集,用于目标检测和分割任务
  • FasionMNIST:与MNIST类似但更复杂

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torchvison.datasets as datasets
import torchvison.transforms as transforms

transform = transform.Compose([
# 调整图像大小为 256x256
transforms.Resize(256),
# 中心裁剪得到 224x224 的图像
transforms.CenterCrop(224),
# 将图像转换为张量
transforms.ToTensor(),
# 对图像进行归一化,这里使用的是 ImageNet 数据集的均值和标准差
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载 CIFAR-10 训练集
cifar10 = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

参数解释:

  • root:数据集的根目录,即数据集存储在本地文件系统中的位置,用于存储下载的数据集
  • train:布尔值,指定是加载训练集还是测试集,如果是True,则加载训练集
  • download:布尔值,如果为True,则自动下载数据集
  • transform:一个或多个图像变换操作,用于预处理图像

其他数据集的使用方式与 CIFAR-10 类似,只是参数和变换可能有所不同。例如,加载 MNIST 数据集时,由于图像是灰度的,可能不需要归一化操作,或者需要将单通道图像扩展为三通道。

如果datasets中的数据集下载速度很慢,可以使用如下方式:

1
2
3
4
5
6
7
8
9
10
11
# 如果没有six模块,使用pip下载
from six.moves import urllib

# 设置代理地址和端口
proxy = urllib.request.ProxyHandler({'http': '127.0.0.1:1081', 'https': '127.0.0.1:1081'})

# 构建一个使用代理设置的新opener
opener = urllib.request.build_opener(proxy)

# 安装opener,使其成为全局urllib.request使用的opener
urllib.request.install_opener(opener)

然后再使用datasets下载:

1
2
3
import torchvision

dataset = torchvision.datasets.FasionMNIST('./data', download=True)

用于加载自定义数据集

可以使用torchvision.datasets.ImageFolder方法,适用于以下场景:

  • 图像分类任务,其中每个文件夹代表一个类别

ImageFolder假设数据集的文件夹结构是以类别命名的子文件夹,每个子文件夹中包含属于该类别的图像文件。ImageFolder 自动将文件夹名称作为类别标签,并加载图像数据。

数据集结构:

1
2
3
4
5
6
7
8
9
10
dataset/
class1/
img1.jpg
img2.jpg
...
class2/
img1.jpg
img2.jpg
...
...

在这里,class1class2 等是类别的名称,也是文件夹的名称。每个文件夹内部包含了属于该类别的所有图像。

在使用 ImageFolder 加载这个数据集时,class_1class_2class_3 将分别被自动标注为 0、1 和 2。这些整数索引就是模型训练时使用的类别标签。

主要参数:

  • root:数据集的根目录路径。
  • transform:一个函数,用于对图像进行预处理和增强。
  • target_transform:一个函数,用于对标签进行转换。
  • loader:用于加载图像的函数,默认是 default_loader,它使用 PIL 来加载图像。
  • is_valid_file:一个函数,用于判断文件是否有效(例如,根据文件扩展名)。

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
from torchvision import datasets, transforms

transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])

# 加载数据集
dataset = datasets.ImageFolder(root='path/to/data', transform=transform)

# 数据集的类别标签到索引的映射
class_to_idx = dataset.class_to_idx
print(class_to_idx)

输出:{'class_1': 0, 'class_2': 1, 'class_3': 2}

1
2
3
# 获取类别名称列表
classes = dataset.classes
print(classes)

输出:['class_1', 'class_2', 'class_3']

torchvision.transforms

torchvision.transforms 是一个提供了一系列图像变换(transform)操作的模块,这些操作可以在加载图像数据时对其进行预处理和增强。

常见变换:

  1. ToTensor:将PIL图像或NumPy ndarray转换为Tensor,并自动将像素值从[0, 255]范围归一化到[0, 1],将颜色通道从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式。参数可以为PIL图片或者Ndarray数组

PIL图像或NumPy数组到Tensor:

  • 对于PIL图像,ToTensor 会将其转换为NumPy数组。
  • 对于NumPy数组,ToTensor 会直接使用该数组。

调整颜色通道顺序:

  • 输入图像的形状通常是 (H, W, C),其中 H 是高度,W 是宽度,C 是通道数(例如,RGB图像的 C 为3)。
  • ToTensor 会将形状转换为 (C, H, W),这是PyTorch张量表示图像的默认格式。

归一化像素值:

  • 像素值通常在 [0, 255] 范围内。
  • ToTensor 会将像素值除以255,将其归一化到 [0, 1) 范围。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torchvision import transforms
from PIL import Image
import torch

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 创建ToTensor变换
to_tensor_transform = transforms.ToTensor()

# 应用变换
tensor_img = to_tensor_transform(img)

# tensor_img现在是形状为(C, H, W)的Tensor,像素值在[0, 1)范围内
print(tensor_img)
print(tensor_img.shape)
  1. Normalize:用均值(mean)和标准差(std)对图像的每个通道进行归一化,使得每个通道的平均值和标准差为mean[i]std[i]。这种处理可以消除图像数据中的量纲差异,使得模型训练更加稳定,参数为meanstd
  • mean(sequence):一个序列,表示每个通道的均值。序列的长度应该与图像的通道数相匹配

  • std(sequence):一个序列,表示每个通道的标准差。序列的长度也应该与图像的通道数相匹配

Normalize 对象期望输入是一个Tensor,且形状为 [C, H, W],其中 C 是通道数,H 是高度,W 是宽度。通常在应用 Normalize 之前,会先使用 ToTensor 变换将PIL图像或NumPy数组转换为Tensor。

在应用 Normalize 之前,确保图像数据的像素值已经转换为合适的范围,通常是通过 ToTensor 变换将像素值从 [0, 255] 转换到 [0, 1)

均值和标准差的来源:均值和标准差通常是通过对大量训练图像进行统计得到的。在许多预训练模型中,这些值是预先计算好的,并且作为模型的一部分提供。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from torchvision import transforms
from PIL import Image
import torch

# 加载一张图像
img = Image.open('path_to_image.jpg')

# 将图像转换为Tensor
to_tensor_transform = transforms.ToTensor()
tensor_img = to_tensor_transform(img)

# 定义Normalize变换
normalize_transform = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)

# 应用变换
normalized_img = normalize_transform(tensor_img)

# normalized_img现在是标准化后的图像Tensor
  1. Resize:调整图像的大小到指定的尺寸,尺寸可以是单个整数,表示图像的短边将被调整到该值长边将按比例缩放;也可以是一个元组,表示图像将被调整到指定的宽度和高度。参数:

    • size (int or sequence): 指定调整后的图像尺寸。如果是一个整数,那么图像的短边将被调整到这个值,长边将按比例缩放。如果是一个元组,那么图像将被调整到指定的宽度和高度。

    • interpolation (int, optional): 指定插值方法。默认是 PIL.Image.BILINEAR。其他可选的插值方法包括 PIL.Image.NEARESTPIL.Image.BICUBIC 等。

      • NEAREST: 最近邻插值,速度最快,但质量最差
      • BILINEAR: 双线性插值,平衡了速度和质量
      • BICUBIC: 双三次插值,质量更好,但速度较慢
      • LANCZOS: Lanczos插值,高质量,适用于放大图像

Resize对象输入的为PIL图像或Ndarray数组

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torchvision import transforms
from PIL import Image

# 加载一张图像
img = Image.open('path_to_image.jpg')

# 定义Resize变换,将图像的短边调整到256
resize_transform = transforms.Resize(256)

# 应用变换
resized_img = resize_transform(img)

# 显示调整后的图像
resized_img.show()

# 定义Resize变换,将图像调整到宽256,高128
resize_transform = transforms.Resize((256, 128))

# 应用变换
resized_img = resize_transform(img)

# 显示调整后的图像
resized_img.show()
  1. CenterCrop:从图像中心裁剪出指定大小的区域。参数:
    • size:一个表示裁剪尺寸的元组 (height, width),或者一个表示正方形裁剪尺寸的整数。如果是一个整数,那么裁剪区域的高度和宽度将相等

CenterCrop对象的输入为PIL图像或Ndarray数组,如果输入是NumPy数组,它应该是一个形状为 (H, W, C) 的数组。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 创建CenterCrop变换,裁剪出大小为(224, 224)的区域
center_crop_transform = transforms.CenterCrop((224, 224))

# 应用变换
cropped_img = center_crop_transform(img)

# cropped_img现在是裁剪后的图像,尺寸为(224, 224)
cropped_img.show()
  1. RandomCrop:随机裁剪出指定大小的区域。参数:
    • size:一个表示裁剪尺寸的元组 (height, width),或者一个表示正方形裁剪尺寸的整数。如果是一个整数,那么裁剪区域的高度和宽度将相等
    • padding(可选):一个表示填充大小的元组 (padding_left, padding_top, padding_right, padding_bottom),或者一个表示等边填充的整数。填充用于在裁剪前增加图像的边界,以便于在边界附近进行裁剪
    • pad_if_needed(可选):布尔值,表示是否只在需要时进行填充(即当原始图像尺寸小于裁剪尺寸时)
    • fill(可选):填充颜色,默认为0。用于指定填充区域的像素值
    • padding_mode(可选):填充模式,可以是 constantedgereflectsymmetric

RandomCrop对象输入的为PIL图像或Ndarray数组,如果输入是NumPy数组,它应该是一个形状为 (H, W, C) 的数组。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 创建RandomCrop变换,裁剪出大小为(224, 224)的区域
random_crop_transform = transforms.RandomCrop((224, 224), padding=10, fill=0, padding_mode='constant')

# 应用变换
cropped_img = random_crop_transform(img)

# cropped_img现在是随机裁剪后的图像,尺寸为(224, 224)
cropped_img.show()
  1. RandomHorizontalFlip:以一定的概率随机水平翻转图像。参数:
    • p:一个介于0和1之间的浮点数,表示图像被水平翻转的概率。默认值为0.5,即图像有50%的概率被翻转

RandomHorizontalFlip对象可以处理PIL图像和NumPy数组。如果输入是NumPy数组,它应该是一个形状为 (H, W, C) 的数组

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 创建RandomHorizontalFlip变换,设置翻转概率为0.5
horizontal_flip_transform = transforms.RandomHorizontalFlip(p=0.5)

# 应用变换
flipped_img = horizontal_flip_transform(img)

# flipped_img现在是可能被水平翻转后的图像
flipped_img.show()
  1. RandomVerticalFlip:以一定的概率随机垂直翻转图像。参数:
    • p:一个介于0和1之间的浮点数,表示图像被垂直翻转的概率。默认值为0.5,即图像有50%的概率被翻转

RandomVerticalFlip 可以处理PIL图像和NumPy数组。如果输入是NumPy数组,它应该是一个形状为 (H, W, C) 的数组

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 创建RandomVerticalFlip变换,设置翻转概率为0.5
vertical_flip_transform = transforms.RandomVerticalFlip(p=0.5)

# 应用变换
flipped_img = vertical_flip_transform(img)

# flipped_img现在是可能被垂直翻转后的图像
flipped_img.show()
  1. RandomRotation: 随机旋转图像一定的角度。参数:
    • degrees:旋转的角度范围,可以是一个单一的数值(表示在 -degreesdegrees 之间随机选择),也可以是一个元组 (min, max)(表示在 minmax 之间随机选择)
    • resample:用于指定旋转时使用的重采样滤波器,默认为 PIL.Image.NEAREST。其他选项包括 PIL.Image.BILINEARPIL.Image.BICUBIC
    • expand:布尔值,表示是否扩大图像的边界以适应旋转后的图像。如果为 True,则输出图像的尺寸可能会比原始图像大
    • center:旋转的中心点,默认为图像的中心。可以指定为 (x, y) 形式的元组
    • fill:用于填充旋转后图像边缘的颜色,可以是一个整数或一个元组。默认为0,表示黑色

RandomRotation 可以处理PIL图像和NumPy数组。如果输入是NumPy数组,它应该是一个形状为 (H, W, C) 的数组

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 创建RandomRotation变换,设置旋转角度范围为-45到45度
rotation_transform = transforms.RandomRotation(degrees=(-45, 45))

# 应用变换
rotated_img = rotation_transform(img)

# rotated_img现在是可能被旋转后的图像
rotated_img.show()

组合变换:

  • Compose:将多个变换操作组合在一起,按顺序应用。参数
    • transforms:一个变换列表,列表中的每个元素都是一个图像变换对象

Compose 中的变换需要兼容输入图像的类型。例如,ToTensor 变换将PIL图像或NumPy数组转换为Tensor,之后的变换需要能够处理Tensor类型

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torchvision import transforms
from PIL import Image

# 定义一个图像变换序列
transform_sequence = transforms.Compose([
transforms.Resize((256, 256)), # 将图像缩放到256x256
transforms.RandomRotation(30), # 随机旋转图像,角度在-30到30度之间
transforms.CenterCrop(224), # 中心裁剪得到224x224的图像
transforms.ToTensor(), # 将图像转换为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化图像
])

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 应用定义好的变换序列
transformed_img = transform_sequence(img)

# transformed_img现在是经过一系列变换后的图像Tensor

其他变换:

  1. Pad: 对图像进行填充。填充是指在图像的边界添加额外的像素,通常用于调整图像尺寸、保持图像的宽高比(在将图像缩放到特定尺寸时,通过填充可以保持原始图像的宽高比)或为后续的图像处理步骤(如卷积神经网络中的卷积操作)做准备。某些神经网络架构要求输入图像具有特定的尺寸,通过填充可以满足这些要求。参数:
    • padding:整数或元组,表示要添加的像素数。如果是一个整数,则所有边都添加相同数量的像素。如果是元组,则格式应为 (left, top, right, bottom),分别表示左、上、右、下边要添加的像素数
    • fill:填充像素的值,默认为0。可以是整数或元组,如果是元组,则应与图像的通道数相匹配
    • padding_mode:填充模式,默认为 constant。其他可选模式包括 edgereflectsymmetric

填充模式:

  • 'constant':用常量值填充,默认为0。
  • 'edge':用边缘像素值填充。
  • 'reflect':以边缘为轴进行反射填充。
  • 'symmetric':以边缘为轴进行对称填充。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 定义一个填充变换,所有边填充10个像素,填充值为0
pad_transform = transforms.Pad(padding=10, fill=0, padding_mode='constant')

# 应用填充变换
padded_img = pad_transform(img)

# padded_img现在是经过填充后的图像
  1. Grayscale: 将图像转换为灰度图。灰度图像只包含亮度信息,没有颜色信息,每个像素值表示灰度级别。参数:
    • num_output_channels:输出图像的通道数,默认为1,可以选1或3。如果设置为3,则输出图像将为3通道,但所有通道的值相同,相当于将灰度图扩展为伪彩色图

彩色图像到灰度图像的转换通常使用以下公式:

\[ \text{Grayscale} = 0.2989 \times R + 0.5870 \times G + 0.1140 \times B \] 其中 \(R, G, B\) 分别是红色、绿色和蓝色通道的像素值。这个公式是基于人眼对不同颜色敏感度的加权平均。

比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
from torchvision import transforms
from PIL import Image

# 加载一张PIL图像
img = Image.open('path_to_image.jpg')

# 定义一个灰度变换
grayscale_transform = transforms.Grayscale(num_output_channels=1)

# 应用灰度变换
grayscale_img = grayscale_transform(img)

# grayscale_img现在是灰度图像
  • RandomAffine: 随机应用仿射变换。

  • RandomPerspective: 随机应用透视变换。

  • RandomErasing: 随机擦除图像中的部分区域。

  • ColorJitter: 随机改变图像的亮度、对比度、饱和度和色调。

    • brightness:一个浮点数或元组,表示亮度调整的范围。如果是一个浮点数,则表示亮度调整的最大幅度(相对于原始亮度)。如果是元组 (min, max),则表示亮度调整的最小和最大幅度。值为0表示不调整亮度。
    • contrast:一个浮点数或元组,表示对比度调整的范围。解释与 brightness 类似。
    • saturation:一个浮点数或元组,表示饱和度调整的范围。解释与 brightness 类似。
    • hue:一个浮点数或元组,表示色相调整的范围。解释与 brightness 类似,但通常这个值较小,因为色相的调整对图像外观影响较大。

    比如:

    1
    transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
    • brightness=0.5:亮度可以在原始亮度的 50% 到 150% 之间随机调整。
    • contrast=0.5:对比度可以在原始对比度的 50% 到 150% 之间随机调整。
    • saturation=0.5:饱和度可以在原始饱和度的 50% 到 150% 之间随机调整。
    • hue=0.1:色相可以在原始色相的 ±10% 之间随机调整。
  • RandomResizedCrop(size, scale, ratio):用于随机裁剪图像的一部分,并将其缩放到指定的大小。

    • size:一个整数或元组,表示输出图像的大小。如果是一个整数,则输出图像是正方形;如果是元组 (height, width),则输出图像是矩形。
    • scale:一个元组 (min_area, max_area),表示裁剪区域面积相对于原始图像面积的比例范围。例如,(0.8, 1.0) 表示裁剪区域的面积将在原始图像面积的 80% 到 100% 之间。
    • ratio:一个元组 (min_ratio, max_ratio),表示裁剪区域的宽高比范围。例如,(0.9, 1.1) 表示裁剪区域的宽高比将在 0.9 到 1.1 之间。

torchvision.model

torchvision.model是PyTorch的torchvision库中的一个模块,它提供了许多预训练的深度学习模型,这些模型主要用于图像分类、目标检测、语义分割等计算机视觉任务。这些模型都是基于PyTorch构建的,因此可以很容易地集成到PyTorch项目中。

常用模型:

  • 图像分类:(ImageNet数据集,\(224\times224\)尺寸图像)
    • resnet18, resnet34, resnet50, resnet101, resnet152:ResNet 系列。
    • alexnet:AlexNet。
    • vgg11, vgg13, vgg16, vgg19:VGG 系列。
    • squeezenet1_0, squeezenet1_1:SqueezeNet。
    • densenet121, densenet169, densenet201, densenet161:DenseNet 系列。
    • inception_v3:Inception v3。
    • googlenet:GoogLeNet。
    • shufflenet_1x1, shufflenet_2x2:ShuffleNet 系列。
    • mobilenet_v2:MobileNet v2。
    • resnext50_32x4d, resnext101_32x8d:ResNeXt 系列。
    • wide_resnet50_2, wide_resnet101_2:Wide ResNet 系列。
    • mnasnet0_5, mnasnet1_0:MNASNet 系列。
  • 目标检测
    • fasterrcnn_resnet50_fpn:Faster R-CNN with ResNet-50 FPN。
    • maskrcnn_resnet50_fpn:Mask R-CNN with ResNet-50 FPN。
    • keypointrcnn_resnet50_fpn:Keypoint R-CNN with ResNet-50 FPN。
  • 语义分割
    • fcn_resnet101:FCN with ResNet-101。
    • deeplabv3_resnet101:DeepLabV3 with ResNet-101。

比如:

1
2
3
import torchvision.models as models

model = models.resnet18(pretrained=False, progress=True, **kwargs)

参数说明:

  1. pretrained (布尔值,默认为False):
    • 如果设置为True,则下载并加载预训练的权重。
    • 如果设置为False,则不加载预训练权重,模型将以随机初始化的权重开始。
  2. progress (布尔值,默认为True):
    • 如果设置为True,则在下载预训练权重时显示进度条。
    • 如果设置为False,则不显示进度条。
  3. **kwargs:
    • 这是一个可变长度参数列表,允许我们传递额外的关键字参数给模型的构造函数。
    • 例如,我们可以通过kwargs传递num_classes来指定输出层的类别数。

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torchvision.models as models
import torch.nn as nn

# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)

# 微调:
# 替换模型的最后一层
num_classes = 10 # 假设新任务有10个类别
model.fc = nn.Linear(model.fc.in_features, num_classes)

# 只训练最后一层
for param in model.parameters():
param.requires_grad = False
for param in model.fc.parameters():
param.requires_grad = True

# 使用适当的优化器进行训练
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

torchvision.io

torchvision.ops


torchvision
https://blog.shinebook.net/2025/03/08/人工智能/pytorch/torchvision/
作者
X
发布于
2025年3月8日
许可协议