您的当前位置:首页正文

基于PyTorch实现一个简单的CNN图像分类器

2023-12-17 来源:九壹网
基于PyTorch实现⼀个简单的CNN图像分类器

⽬录

⼀. 加载数据

1. 继承Dataset类并重写关键⽅法2. 使⽤Dataloader加载数据⼆. 模型设计三. 训练四. 测试结语⼀. 加载数据

Pytorch的数据加载⼀般是⽤torch.utils.data.Dataset与torch.utils.data.Dataloader两个类联合进⾏。我们需要继承Dataset来定义⾃⼰的数据集类,然后在训练时⽤Dataloader加载⾃定义的数据集类。

1. 继承Dataset类并重写关键⽅法

pytorch的dataset类有两种:Map-style datasets和Iterable-style datasets。前者是我们常⽤的结构,⽽后者是当数据集难以(或不可能)进⾏随机读取时使⽤。在这⾥我们实现Map-style dataset。

继承torch.utils.data.Dataset后,需要重写的⽅法有:__len__与__getitem__⽅法,其中__len__⽅法需要返回所有数据的数量,⽽__getitem__则是要依照给出的数据索引获取对应的tensor类型的Sample,除了这两个⽅法以外,⼀般还需要实现__init__⽅法来初始化⼀些变量。话不多说,直接上代码。

'''

包括了各种数据集的读取处理,以及图像相关处理⽅法'''

from torch.utils.data import Datasetimport torchimport osimport cv2

from Config import mycfgimport random

import numpy as np

class ImageClassifyDataset(Dataset):

def __init__(self, imagedir, labelfile, classify_num, train=True): '''

这⾥进⾏⼀些初始化操作。 '''

self.imagedir = imagedir self.labelfile = labelfile

self.classify_num = classify_num self.img_list = [] # 读取标签

with open(self.labelfile, 'r') as fp: lines = fp.readlines() for line in lines:

filepath = os.path.join(self.imagedir, line.split(\";\")[0].replace('\\\\', '/')) label = line.split(\";\")[1].strip('\\n')

self.img_list.append((filepath, label)) if not train:

self.img_list = random.sample(self.img_list, 50)

def __len__(self):

return len(self.img_list)

def __getitem__(self, item): '''

这个函数是关键,通过item(索引)来取数据集中的数据,

⼀般来说在这⾥才将图像数据加载⼊内存,之前存的是图像的保存路径 '''

_int_label = int(self.img_list[item][1]) # label直接⽤0,1,2,3,4...表⽰不同类别 label = torch.tensor(_int_label,dtype=torch.long) img = self.ProcessImgResize(self.img_list[item][0]) return img, label

def ProcessImgResize(self, filename): '''

对图像进⾏⼀些预处理 '''

_img = cv2.imread(filename)

_img = cv2.resize(_img, (mycfg.IMG_WIDTH, mycfg.IMG_HEIGHT), interpolation=cv2.INTER_CUBIC) _img = _img.transpose((2, 0, 1)) _img = _img / 255

_img = torch.from_numpy(_img) _img = _img.to(torch.float32) return _img

有⼀些的数据集类⼀般还会传⼊⼀个transforms函数来构造⼀个图像预处理序列,传⼊transforms函数的⼀个好处是作为参数传⼊的话可以对⼀些⾮本地数据集中的数据进⾏操作(⽐如直接通过torchvision获取的⼀些预存数据集CIFAR10等等),除此之外就是torchvision.transforms⾥⾯有⼀些预定义的图像操作函数,可以直接像拼积⽊⼀样拼成⼀个图像处理序列,很⽅便。我这⾥因为是⽤我⾃⼰下载到本地的数据集,⽽且⽐较简单就直接⽤⾃⼰的函数来操作了。

2. 使⽤Dataloader加载数据

实例化⾃定义的数据集类ImageClassifyDataset后,将其传给DataLoader作为参数,得到⼀个可遍历的数据加载器。可以通过参数batch_size控制批处理⼤⼩,shuffle控制是否乱序读取,num_workers控制⽤于读取数据的线程数量。

from torch.utils.data import DataLoader

from MyDataset import ImageClassifyDataset

dataset = ImageClassifyDataset(imagedir, labelfile, 10)

dataloader = DataLoader(dataset, batch_size=5, shuffle=True,num_workers=5)for index, data in enumerate(dataloader): print(index) # batch索引

print(data) # ⼀个batch的{img,label}

⼆. 模型设计

在这⾥只讨论深度学习模型的设计,pytorch中的⽹络结构是⼀层⼀层叠出来的,pytorch中预定义了许多可以通过参数控制的⽹络层结构,⽐如Linear、CNN、RNN、Transformer等等具体可以查阅官⽅⽂档中的torch.nn部分。

设计⾃⼰的模型结构需要继承torch.nn.Module这个类,然后实现其中的forward⽅法,⼀般在__init__中设定好⽹络模型的⼀些组件,然后在forward⽅法中依据输⼊输出顺序拼装组件。

'''

包括了各种模型、⾃定义的loss计算⽅法、optimizer'''

import torch.nn as nn

class Simple_CNN(nn.Module): def __init__(self, class_num):

super(Simple_CNN, self).__init__() self.class_num = class_num self.conv1 = nn.Sequential(

nn.Conv2d( # input: 3,400,600 in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2 ),

nn.Conv2d(

in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2 ),

nn.AvgPool2d(2), # 16,400,600 --> 16,200,300 nn.BatchNorm2d(16), nn.LeakyReLU(), nn.Conv2d(

in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=2

),

nn.Conv2d(

in_channels=16, out_channels=8, kernel_size=5, stride=1, padding=2 ),

nn.AvgPool2d(2), # 8,200,300 --> 8,100,150 nn.BatchNorm2d(8), nn.LeakyReLU(), nn.Conv2d(

in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1 ),

nn.Conv2d(

in_channels=8, out_channels=1, kernel_size=3, stride=1, padding=1 ),

nn.AvgPool2d(2), # 1,100,150 --> 1,50,75 nn.BatchNorm2d(1), nn.LeakyReLU() )

self.line = nn.Sequential( nn.Linear(

in_features=50 * 75,

out_features=self.class_num ),

nn.Softmax() )

def forward(self, x): x = self.conv1(x)

x = x.view(-1, 50 * 75) y = self.line(x) return y

上⾯我定义的模型中包括卷积组件conv1和全连接组件line,卷积组件中包括了⼀些卷积层,⼀般是按照{卷积层、池化层、激活函数}的顺序拼接,其中我还在激活函数之前添加了⼀个BatchNorm2d层对上层的输出进⾏正则化以免传⼊激活函数的值过⼩(梯度消失)或过⼤(梯度爆炸)。

在拼接组件时,由于我全连接层的输⼊是⼀个⼀维向量,所以需要将卷积组件中最后的50 × 75 50\imes 7550×75⼤⼩的矩阵展平成⼀维的再传⼊全连接层(x.view(-1,50*75))

三. 训练

实例化模型后,⽹络模型的训练需要定义损失函数与优化器,损失函数定义了⽹络输出与标签的差距,依据不同的任务需要定义不同的合适的损失函数,⽽优化器则定义了神经⽹络中的参数如何基于损失来更新,⽬前神经⽹络最常⽤的优化器就是SGD(随机梯度下降算法) 及其变种。

在我这个简单的分类器模型中,直接⽤的多分类任务最常⽤的损失函数CrossEntropyLoss()以及优化器SGD。

self.cnnmodel = Simple_CNN(mycfg.CLASS_NUM)

self.criterion = nn.CrossEntropyLoss() # 交叉熵,标签应该是0,1,2,3...的形式⽽不是独热的

self.optimizer = optim.SGD(self.cnnmodel.parameters(), lr=mycfg.LEARNING_RATE, momentum=0.9)

训练过程其实很简单,使⽤dataloader依照batch读出数据后,将input放⼊⽹络模型中计算得到⽹络的输出,然后基于标签通过损失函数计算Loss,并将Loss反向传播回神经⽹络(在此之前需要清理上⼀次循环时的梯度),最后通过优化器更新权重。训练部分代码如下:

for each_epoch in range(mycfg.MAX_EPOCH): running_loss = 0.0 self.cnnmodel.train()

for index, data in enumerate(self.dataloader): inputs, labels = data

outputs = self.cnnmodel(inputs) loss = self.criterion(outputs, labels)

self.optimizer.zero_grad() # 清理上⼀次循环的梯度 loss.backward() # 反向传播 self.optimizer.step() # 更新参数 running_loss += loss.item() if index % 200 == 199:

print(\"[{}] loss: {:.4f}\".format(each_epoch, running_loss/200)) running_loss = 0.0 # 保存每⼀轮的模型

model_name = 'classify-{}-{}.pth'.format(each_epoch,round(all_loss/all_index,3)) torch.save(self.cnnmodel,model_name) # 保存全部模型

四. 测试

测试和训练的步骤差不多,也就是读取模型后通过dataloader获取数据然后将其输⼊⽹络获得输出,但是不需要进⾏反向传播的等操作了。⽐较值得注意的可能就是准确率计算⽅⾯有⼀些⼩技巧。

acc = 0.0count = 0

self.cnnmodel = torch.load('mymodel.pth')self.cnnmodel.eval()

for index, data in enumerate(dataloader_eval): inputs, labels = data # 5,3,400,600 5,10 count += len(labels)

outputs = cnnmodel(inputs)

_,predict = torch.max(outputs, 1)

acc += (labels == predict).sum().item()

print(\"[{}] accurancy: {:.4f}\".format(each_epoch, acc / count))

我这⾥采⽤的是保存全部模型并加载全部模型的⽅法,这种⽅法的好处是在使⽤模型时可以完全将其看作⼀个⿊盒,但是在模型⽐较⼤时这种⽅法会很费事。此时可以采⽤只保存参数不保存⽹络结构的⽅法,在每⼀次使⽤模型时需要读取参数赋值给已经实例化的模型:

torch.save(cnnmodel.state_dict(), \"my_resnet.pth\")cnnmodel = Simple_CNN()

cnnmodel.load_state_dict(torch.load(\"my_resnet.pth\"))

结语

⾄此整个流程就说完了,是⼀个⼩⽩级的图像分类任务流程,因为前段时间⼀直在做android⽅⾯的事,所以有点⽣疏了,就写了这篇博客记录⼀下,之后应该还会写⼀下seq2seq以及image caption任务⽅⾯的模型构造与训练过程,完整代码之后也会统⼀放到github上给⼤家做参考。

以上就是基于PyTorch实现⼀个简单的CNN图像分类器的详细内容,更多关于PyTorch实现CNN图像分类器的资料请关注其它相关⽂章!

因篇幅问题不能全部显示,请点此查看更多更全内容