PyTorch inference 与图像分类简述

PyTorch inference 与图像分类简述

这篇文章主要介绍图像分类模型的 inference(推理)流程。

以 ResNet18 为例。

首先加载训练好的模型参数:

1
2
3
4
5
6
7
8
9
resnet18 = models.resnet18()

# Replace the final fully connected layer with a 2-way classifier.
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 2)

# Load the trained parameters. map_location="cpu" lets a checkpoint that was saved
# from a GPU run be loaded on a CPU-only machine; the model is moved to the target
# device afterwards.
checkpoint = torch.load(m_path, map_location="cpu")
resnet18.load_state_dict(checkpoint['model_state_dict'])

然后,比较重要的是把模型放到目标设备(如 GPU)上,并切换到 eval 模式:

1
2
# Move the model to the target device, then switch it to eval mode.
# nn.Module.to() and nn.Module.eval() both return the module itself, so they chain.
resnet18.to(device).eval()

在 inference 时,代码要放在 with torch.no_grad(): 下。torch.no_grad() 会关闭梯度计算(不再记录用于反向传播的计算图),从而减少内存占用、加快运算速度。

根据路径读取图片,把图片转换为形状为 \(C \times H \times W\) 的 tensor,然后使用 unsqueeze_(0) 方法在最前面增加一个 batch 维度,得到形状为 \(B \times C \times H \times W\)(这里 \(B = 1\))的输入,再把 tensor 放到模型所在的设备(如 GPU)上。模型的输出 outputs 的形状是 \(1 \times 2\)。torch.max(outputs, 0) 返回 outputs 中每一列的最大元素及其索引,torch.max(outputs, 1) 返回每一行的最大元素及其索引。这里取 torch.max(..., 1) 的第二个返回值 pred_int 作为预测类别的索引,然后根据索引获得 label:pred_str = classes[int(pred_int)]。关键代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# torch.no_grad() turns off gradient tracking for everything inside the block.
with torch.no_grad():
    for idx, img_name in enumerate(img_names):

        path_img = os.path.join(img_dir, img_name)

        # step 1/4 : path --> img
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so img_rgb stays valid after the file is closed.
        with Image.open(path_img) as img:
            img_rgb = img.convert('RGB')

        # step 2/4 : img --> tensor, then add a leading batch dimension of size 1
        img_tensor = img_transform(img_rgb, inference_transform)
        img_tensor.unsqueeze_(0)
        img_tensor = img_tensor.to(device)

        # step 3/4 : tensor --> vector
        outputs = resnet18(img_tensor)

        # step 4/4 : get label
        # Index of the largest output in each row; `.data` is unnecessary under no_grad.
        _, pred_int = torch.max(outputs, 1)
        pred_str = classes[int(pred_int)]

全部代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import time
import torch.nn as nn
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
import torchvision.models as models
import enviroments
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Inference device. CPU is hard-coded here; to use a GPU when one is present, use
#     torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# ---- visualisation config -------------------------------------------------
vis = True      # False disables the matplotlib prediction grids
vis_row = 4     # each figure shows a vis_row x vis_row grid of images

# ---- preprocessing ---------------------------------------------------------
# Per-channel normalisation statistics (the commonly used ImageNet values);
# presumably the same ones used at training time -- confirm against the training script.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Resize (shorter side to 256), center-crop 224x224, convert to a tensor, normalise.
inference_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]
)

# Class names indexed by the model's output position (0 -> "ants", 1 -> "bees").
classes = ["ants", "bees"]


def img_transform(img_rgb, transform=None):
    """
    Apply ``transform`` to an image so that the model can consume it.

    :param img_rgb: PIL Image
    :param transform: callable such as a torchvision.transforms.Compose pipeline; required
    :return: whatever ``transform`` returns (a tensor for the pipelines used in this script)
    :raises ValueError: if ``transform`` is None
    """
    if transform is not None:
        return transform(img_rgb)
    raise ValueError("找不到transform!必须有transform对img进行处理")


def get_img_name(img_dir, format="jpg"):
    """
    Return the sorted names of the files in ``img_dir`` whose extension is ``format``.

    The match is on the ``.<format>`` suffix and is case-insensitive, so ``"jpg"``
    and ``".jpg"`` both select ``a.jpg`` and ``b.JPG`` but not a name like ``notjpg``.
    The directory is not searched recursively.

    :param img_dir: str, directory to scan
    :param format: str, file extension with or without the leading dot
    :return: list of str, sorted so the processing order is deterministic
             (os.listdir order is arbitrary)
    :raises ValueError: if no file with that extension is found
    """
    suffix = "." + format.lstrip(".").lower()
    img_names = sorted(
        name for name in os.listdir(img_dir) if name.lower().endswith(suffix)
    )

    if not img_names:
        raise ValueError("{}下找不到{}格式数据".format(img_dir, format))
    return img_names


def get_model(m_path, vis_model=False, num_classes=2):
    """
    Build a ResNet-18 with a ``num_classes``-way head and load trained weights into it.

    :param m_path: str, path to a checkpoint dict containing a ``'model_state_dict'`` entry
    :param vis_model: bool, if True print a layer summary via torchsummary
    :param num_classes: int, output size of the replacement fc layer; must match
                        the checkpoint (2 for the ants/bees model)
    :return: torch.nn.Module on the CPU, still in the default (training) mode --
             callers move it to their device and call .eval()
    """
    resnet18 = models.resnet18()

    # Replace the default classifier with a num_classes-way fully connected layer.
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, num_classes)

    # map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; the caller moves the model to its target device afterwards.
    checkpoint = torch.load(m_path, map_location="cpu")
    resnet18.load_state_dict(checkpoint['model_state_dict'])

    if vis_model:
        # Optional dependency, imported only when a summary is requested.
        from torchsummary import summary
        summary(resnet18, input_size=(3, 224, 224), device="cpu")

    return resnet18


if __name__ == "__main__":

    img_dir = os.path.join(enviroments.hymenoptera_data_dir, "val/bees")
    model_path = "./checkpoint_14_epoch.pkl"
    time_total = 0
    img_list, img_pred = list(), list()

    # 1. data
    img_names = get_img_name(img_dir)
    num_img = len(img_names)  # >= 1: get_img_name raises when nothing matches

    # 2. model: move to the target device and switch to eval mode before inference
    resnet18 = get_model(model_path, True)
    resnet18.to(device)
    resnet18.eval()

    use_cuda = device.type == "cuda"

    # 3. inference: no_grad turns off gradient tracking for the whole loop
    with torch.no_grad():
        for idx, img_name in enumerate(img_names):

            path_img = os.path.join(img_dir, img_name)

            # step 1/4 : path --> img
            # The context manager closes the file handle; convert() returns a new,
            # fully loaded image, so img_rgb stays valid afterwards.
            with Image.open(path_img) as img:
                img_rgb = img.convert('RGB')

            # step 2/4 : img --> tensor, plus a leading batch dimension of size 1
            img_tensor = img_transform(img_rgb, inference_transform)
            img_tensor.unsqueeze_(0)
            img_tensor = img_tensor.to(device)

            # step 3/4 : tensor --> vector
            # CUDA kernels run asynchronously, so synchronize around the forward pass
            # to time the computation itself rather than just the kernel launches.
            if use_cuda:
                torch.cuda.synchronize(device)
            time_tic = time.perf_counter()
            outputs = resnet18(img_tensor)
            if use_cuda:
                torch.cuda.synchronize(device)
            time_toc = time.perf_counter()

            # step 4/4 : visualization
            # Index of the largest output in each row -> class name.
            _, pred_int = torch.max(outputs, 1)
            pred_str = classes[int(pred_int)]

            if vis:
                img_list.append(img_rgb)
                img_pred.append(pred_str)

                # Show a full vis_row x vis_row grid, or whatever remains after the last image.
                if (idx + 1) % (vis_row * vis_row) == 0 or num_img == idx + 1:
                    for i, (shown_img, shown_pred) in enumerate(zip(img_list, img_pred)):
                        plt.subplot(vis_row, vis_row, i + 1).imshow(shown_img)
                        plt.title("predict:{}".format(shown_pred))
                    plt.show()
                    plt.close()
                    img_list, img_pred = list(), list()

            time_s = time_toc - time_tic
            time_total += time_s

            print('{:d}/{:d}: {} {:.3f}s '.format(idx + 1, num_img, img_name, time_s))

    print("\ndevice:{} total time:{:.1f}s mean:{:.3f}s".
          format(device, time_total, time_total / num_img))
    # Report a GPU only when inference actually ran on one.
    if use_cuda:
        print("GPU name:{}".format(torch.cuda.get_device_name(device)))

总结一下 inference 阶段需要注意的事项:

  • 确保 model 处于 eval 状态,而非 training 状态
  • 设置 torch.no_grad(),减少内存消耗,加快运算速度
  • 数据预处理需要与训练时保持一致,比如图片的通道顺序(RGB 还是 BGR)、resize/crop 尺寸以及归一化使用的均值和方差

torchvision.models 中有很多封装好的模型。


可以分为 3 类:

  • 经典网络
    • alexnet
    • vgg
    • resnet
    • inception
    • densenet
    • googlenet
  • 轻量化网络
    • squeezenet
    • mobilenet
    • shufflenetv2
  • 通过神经网络结构搜索(NAS)得到的网络
    • mnasnet

以 ResNet 为例:


一个残差块有 2 条路径:\(F(x)\) 和 \(x\)。\(F(x)\) 路径拟合残差,不妨称之为残差路径;\(x\) 路径为 identity mapping(恒等映射),称之为 shortcut。图中的 ⊕ 为 element-wise addition(逐元素相加),要求参与运算的 \(F(x)\) 和 \(x\) 的尺寸相同。

shortcut 路径大致可以分成 2 种,取决于残差路径是否改变了 feature map 的数量和尺寸:一种是将输入 \(x\) 原封不动地输出;另一种则需要经过 \(1 \times 1\) 卷积来升维或者降采样,主要作用是使输出与 \(F(x)\) 路径的输出保持 shape 一致,对网络性能的提升并不明显。两种结构如下图所示:


ResNet 网络结构如下:


根据上图,所有的 ResNet 都可以表示为下面的代码。其中 layer1、layer2、layer3、layer4 分别对应 conv2、conv3、conv4、conv5。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)

return x

ResNet 中,所有的 basic block 都没有 pooling 层,降采样是通过 conv 的 stride 实现的:在 conv3、conv4、conv5 的第一个 basic block 的第一个卷积层把特征图尺寸降为一半,同时 feature map 数量增加 1 倍。_make_layer 代码如下。首先判断 stride 是否不为 1,或者输入通道数与输出通道数(planes * block.expansion)是否不相等;只要满足其一,就使用 \(1 \times 1\) 的卷积改变尺寸和通道,再加上 bn 层,作为 downsample 层。然后添加第一个 basic block,并把 downsample 层传给 BasicBlock,作为 shortcut 上的降采样层。接着更新通道数 self.inplanes = planes * block.expansion,继续添加这个 layer 里剩下的 BasicBlock。这些 block 不传 stride 参数(默认为 1),输入和输出通道数也相等,因此不需要 downsample。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
# 首先判断 stride 是否为1,输入通道和输出通道是否相等。不相等则使用 1 X 1 的卷积改变大小和通道 作为 downsample
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)

layers = []
# 然后添加第一个 basic block,把 downsample 传给 BasicBlock 作为降采样的层。
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
# 继续添加这个 layer 里接下来的 BasicBlock
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))

return nn.Sequential(*layers)

输入图片的尺寸是 \(224 \times 224\),首先经过 conv1 层,输出为 \(112 \times 112\)。在 conv2 中,先经过一个 max pool 缩小为 \(56 \times 56\),然后经过两个 basic block 的堆叠。每个 basic block 的结构是 conv->bn->relu->conv->bn->(与 shortcut 相加)->relu,其中卷积采用 same padding,不改变特征图的大小。同理经过 conv3、conv4、conv5,最后经过 average pool 和 fc 层得到 1000 分类。conv1、conv2、conv3、conv4、conv5 称为 layer。ResNet18 名字中的 18 是指带权重的卷积层和全连接层的层数之和(不计 shortcut 上的 \(1 \times 1\) 卷积)。conv1 为 1 层。conv2、conv3、conv4、conv5 各包含 2 个 basic block,每个 block 有 2 个卷积层,因此各为 4 层,合计 16 层。最后再加一层全连接层,\(总层数 = 1 + 4 \times 4 + 1 = 18\)。其他 ResNet 的层数依此类推。

ResNet18、ResNet34 的 basic block 都是一样的,只是每个 layer 里堆叠的 basic block 数量不一样。

basic block 的定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34.

    Residual path: conv3x3 -> norm -> ReLU -> conv3x3 -> norm. The shortcut is the
    input itself, or ``downsample(x)`` when a downsample module is given. The two
    are added element-wise and passed through ReLU. Only the first convolution
    uses ``stride``, so it is the one that can reduce the spatial size.
    """

    expansion = 1  # output channels = planes * expansion
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # When stride != 1 both conv1 and the downsample branch reduce the input,
        # so conv1 takes the stride passed in.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv2 always uses the helper's default stride.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut: identity, or a projection when shape/channels change.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)

可以看到在 basic block 里,只有第一个卷积层会使用传进来的 stride。forward() 函数会判断 downsample 是否为空,如果不为空,就对输入执行降采样操作,其实就是用 \(1 \times 1\) 的卷积改变通道数和大小。最后将 shortcut 与残差路径的输出相加,再经过 relu。

ResNet34 在 PyTorch 中的定义如下:

1
2
3
def resnet34(pretrained=False, progress=True, **kwargs):
    """ResNet-34: BasicBlock stages with 3, 4, 6 and 3 blocks in layer1..layer4.

    All arguments are forwarded unchanged to ``_resnet`` (defined elsewhere, not shown here).
    """
    blocks_per_layer = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, blocks_per_layer, pretrained, progress, **kwargs)

[3, 4, 6, 3] 是指 layer1~layer4 中各自堆叠的 basic block 数量。

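从 ResNet50 开始,basic block 改为 bottleneck。以 conv2 中的第一个 block 为例,其结构是 conv(64,64,1)->conv(64,64,3)->conv(64,256,1)->残差相加,以此类推。这里的记号为 conv(输入通道, 输出通道, 卷积核大小);每个卷积后接 bn,前两个卷积后再接 relu,残差相加后也接 relu;最后一个 \(1 \times 1\) 卷积把通道数扩大为 planes 的 4 倍(expansion = 4)。代码如下: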

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Bottleneck(nn.Module):
    """Three-convolution residual block used by ResNet-50 and deeper.

    Residual path: conv1x1 (inplanes -> width) -> norm -> ReLU ->
    conv3x3 (width -> width, carries stride/groups/dilation) -> norm -> ReLU ->
    conv1x1 (width -> planes * expansion) -> norm. The shortcut is the input itself,
    or ``downsample(x)`` when a downsample module is given. The two are added and
    passed through ReLU.
    """

    expansion = 4  # output channels = planes * expansion
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Inner (bottleneck) channel count, scaled by base_width and groups.
        width = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion
        # When stride != 1 both conv2 and the downsample branch reduce the input.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_channels)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut: identity, or a projection when shape/channels change.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)

评论