首页 > 分享 > 【图像分类】实战篇 (2)5分钟学会用迁移学习ResNet50训练自己的图像分类模型(鸟类识别为例)

【图像分类】实战篇 (2)5分钟学会用迁移学习ResNet50训练自己的图像分类模型(鸟类识别为例)

作者简介:热编程的贝贝,致力于C/C++、Java、Python等多编程语言,热爱跑步健身,喜爱音乐的一位博主。
本文收录于贝贝的日常汇报系列,大家有兴趣的可以看一看
相关专栏深度学习、目标检测系列等,大家有兴趣的可以看一看
C++零基础入门系列,Web入门篇系列正在发展中,喜欢Python、C++的朋友们可以关注一下哦!
如有需要此项目工程,请评论区留言哦 也可联系作者微信 Qwe1398276934

目录

前言

一、构建数据集

二、数据预处理、图像增强和图像标准化

三、模型构建

四、设置优化器、损失函数、学习率等超参数

五、训练和验证

六、预测

前言

此项目主要是小白入门教学,项目可以直接运行,需要请私信!!!

图像分类是深度学习图像处理领域最基本的识别任务。

先上结果 

一、构建数据集

在工程目录鸟类识别下,dataset文件夹下有train和test文件夹下分别放有八种类的鸟文件夹,分别以该类别的名称命名。

 在布谷鸟等鸟类文件夹下分别是自己类别若干的图片,

下面代码为加载自己的数据集,路径使用相对路径

import torchvision.datasets as dsets

trainpath = './dataset/train/'

valpath = './dataset/test/'

trainData = dsets.ImageFolder (trainpath, transform =traintransform ) # 读取训练集,标签就是train⽬录下的⽂件夹的名字,图像保存在格⼦标签下的⽂件夹⾥

valData = dsets.ImageFolder (valpath, transform =valtransform ) #读取验证集

二、数据预处理、图像增强和图像标准化

在图像输入网络之前,首先进行图像增强和图像标准化,有随机旋转、改变颜色、改变成统一大小,并转化成tensor格式,预处理和标准化之后的图片可以正常输入网络。

训练集需要图片增强,因为训练集在训练过程中需要更新模型参数,而验证集为了反向传播梯度,不需要更新模型参数,为了增强模型的泛化性能,引入图像增强。

import torchvision.transforms as transforms

#数据增强的方式

traintransform = transforms .Compose([

transforms .RandomRotation (20), #随机旋转角度

transforms .ColorJitter(brightness=0.1), #颜色亮度

transforms .Resize([224, 224]), #设置成224×224大小的张量

transforms .ToTensor(), # 将图⽚数据变为tensor格式

# transforms.Normalize(mean=[0.485, 0.456, 0.406],

# std=[0.229, 0.224, 0.225]),

])

valtransform = transforms .Compose([

transforms .Resize([224, 224]),

transforms .ToTensor(), # 将图⽚数据变为tensor格式

])

三、模型构建

在进行图像分类之前,需要构建一个模型,在本文中,为了加速收敛,使用torchvision中集成的模型resnet50,并引入预训练权重:

在这里模型的num_of _classes一定要改为自己要分类的数目,此处鸟类为八个文件夹,此处设置为8

import torchvision.models as models

model = models.resnet50(pretrained=True) #pretrained表⽰是否加载已经与训练好的参数

model.fc = torch.nn.Linear(2048, num_of_classes) #将最后的fc层的输出改为标签数量(如3),512取决于原始⽹络fc层的输⼊通道

model = model.to(device) # 如果有GPU,⽽且确认使⽤则保留;如果没有GPU,请删除

四、设置优化器、损失函数、学习率等超参数

criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

batch_size = 16

learning_rate = 1e-4

epoches = 30

num_of_classes=8

五、训练和验证

大概过程,将图片传入模型,优化器梯度清零,计算损失、反传梯度,更新参数,处理输出,计算分类真却的图片个数,最后计算训练精度。

将模型、优化器和损失函数传入函数train开始训练:

def train(model, optimizer, criterion):

model.train()

total_loss = 0

train_corrects = 0

for i, (image, label) in enumerate (tqdm(trainLoader)):

image = Variable(image.to(device))

label = Variable(label.to(device))

optimizer.zero_grad ()

target = model(image)

loss = criterion(target, label)

loss.backward()

optimizer.step()

total_loss += loss.item()

max_value , max_index = torch.max(target, 1)

pred_label = max_index.cpu().numpy()

true_label = label.cpu().numpy()

train_corrects += np.sum(pred_label == true_label)

return total_loss / float(len(trainLoader)), train_corrects / train_sum

以下为验证函数:

def evaluate(model, criterion):

model.eval()

corrects = eval_loss = 0

with torch.no_grad():

for image, label in tqdm(testLoader):

image = Variable(image.to(device))

label = Variable(label.to(device))

pred = model(image)

loss = criterion(pred, label)

eval_loss += loss.item()

max_value, max_index = torch.max(pred, 1)

pred_label = max_index.cpu().numpy()

true_label = label.cpu().numpy()

corrects += np.sum(pred_label == true_label)

return eval_loss / float(len(testLoader)), corrects, corrects / test_sum

六、预测

首先准备标签

list=['孔雀', '布谷鸟', '梅花雀', '燕子', '赤颈鹤', '鹦鹉', '麻雀', '黄鹂']

设置device

if(torch.cuda.is_available()):

device=torch.device('cuda')

else:

device=torch.device('cpu')

加载训练好的模型

model=torch.load("./resnet50_bird.pt",map_location=device)#加载模型

图像标准化,传入训练好的模型

model.eval()

transformer = transforms.Compose([

transforms.Resize((224,224)),

transforms.ToTensor(),

])

filename="./3_0.jpg"

image=Image.open(filename)

img=transformer(image)

img= img.unsqueeze(0).to(device)

pred=model(img)

处理预测的结果

如果这份博客对大家有帮助,希望各位给恒川一个免费的点赞作为鼓励,并评论收藏一下⭐,谢谢大家!!!
制作不易,如果大家有什么疑问或给恒川的意见,欢迎评论区留言。

完整工程代码链接:https://download.csdn.net/download/qq_46644680/89997623?spm=1001.2014.3001.5501

相关知识

机器学习花朵图像分类
基于深度学习和迁移学习的识花实践
度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码
基于深度学习的花卉图像分类识别模型研究
【基于PyTorch实现经典网络架构的花卉图像分类模型】
基于深度学习模型的花卉图像分类代码
如何yolov8训练使用——西红柿叶片病虫害分类数据集,1.4GB,超过2万张图像,共11大类别分类 西红柿数据集
基于深度学习特征的植物病虫害检测
使用迁移学习对花卉进行分类
深度学习实战:AlexNet实现花图像分类

网址: 【图像分类】实战篇 (2)5分钟学会用迁移学习ResNet50训练自己的图像分类模型(鸟类识别为例) https://m.huajiangbk.com/newsview1057224.html

所属分类:花卉
上一篇: 深度学习 细粒度图像识别 (fi
下一篇: 冬季月季清园了,石硫合剂使用有讲