作者简介:热编程的贝贝,致力于C/C++、Java、Python等多编程语言,热爱跑步健身,喜爱音乐的一位博主。
本文收录于贝贝的日常汇报系列,大家有兴趣的可以看一看
相关专栏深度学习、目标检测系列等,大家有兴趣的可以看一看
C++零基础入门系列,Web入门篇系列正在发展中,喜欢Python、C++的朋友们可以关注一下哦!
如有需要此项目工程,请评论区留言哦 也可联系作者微信 Qwe1398276934
目录
前言
一、构建数据集
二、数据预处理、图像增强和图像标准化
三、模型构建
四、设置优化器、损失函数、学习率等超参数
五、训练和验证
六、预测
此项目主要是小白入门教学,项目可以直接运行,需要请私信!!!
图像分类是深度学习图像处理领域最基本的识别任务。
先上结果
在工程目录鸟类识别下,dataset文件夹下有train和test文件夹下分别放有八种类的鸟文件夹,分别以该类别的名称命名。
在布谷鸟等鸟类文件夹下分别是自己类别若干的图片,
下面代码为加载自己的数据集,路径使用相对路径
import torchvision.datasets as dsets
trainpath = './dataset/train/'
valpath = './dataset/test/'
trainData = dsets.ImageFolder (trainpath, transform =traintransform ) # 读取训练集,标签就是train⽬录下的⽂件夹的名字,图像保存在格⼦标签下的⽂件夹⾥
valData = dsets.ImageFolder (valpath, transform =valtransform ) #读取验证集
在图像输入网络之前,首先进行图像增强和图像标准化,有随机旋转、改变颜色、改变成统一大小,并转化成tensor格式,预处理和标准化之后的图片可以正常输入网络。
训练集需要图片增强,因为训练集在训练过程中需要更新模型参数,而验证集为了反向传播梯度,不需要更新模型参数,为了增强模型的泛化性能,引入图像增强。
import torchvision.transforms as transforms
#数据增强的方式
traintransform = transforms .Compose([
transforms .RandomRotation (20), #随机旋转角度
transforms .ColorJitter(brightness=0.1), #颜色亮度
transforms .Resize([224, 224]), #设置成224×224大小的张量
transforms .ToTensor(), # 将图⽚数据变为tensor格式
# transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225]),
])
valtransform = transforms .Compose([
transforms .Resize([224, 224]),
transforms .ToTensor(), # 将图⽚数据变为tensor格式
])
在进行图像分类之前,需要构建一个模型,在本文中,为了加速收敛,使用torchvision中集成的模型resnet50,并引入预训练权重:
在这里模型的num_of _classes一定要改为自己要分类的数目,此处鸟类为八个文件夹,此处设置为8
import torchvision.models as models
model = models.resnet50(pretrained=True) #pretrained表⽰是否加载已经与训练好的参数
model.fc = torch.nn.Linear(2048, num_of_classes) #将最后的fc层的输出改为标签数量(如3),512取决于原始⽹络fc层的输⼊通道
model = model.to(device) # 如果有GPU,⽽且确认使⽤则保留;如果没有GPU,请删除
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
batch_size = 16
learning_rate = 1e-4
epoches = 30
num_of_classes=8
大概过程,将图片传入模型,优化器梯度清零,计算损失、反传梯度,更新参数,处理输出,计算分类真却的图片个数,最后计算训练精度。
将模型、优化器和损失函数传入函数train开始训练:
def train(model, optimizer, criterion):
model.train()
total_loss = 0
train_corrects = 0
for i, (image, label) in enumerate (tqdm(trainLoader)):
image = Variable(image.to(device))
label = Variable(label.to(device))
optimizer.zero_grad ()
target = model(image)
loss = criterion(target, label)
loss.backward()
optimizer.step()
total_loss += loss.item()
max_value , max_index = torch.max(target, 1)
pred_label = max_index.cpu().numpy()
true_label = label.cpu().numpy()
train_corrects += np.sum(pred_label == true_label)
return total_loss / float(len(trainLoader)), train_corrects / train_sum
以下为验证函数:
def evaluate(model, criterion):
model.eval()
corrects = eval_loss = 0
with torch.no_grad():
for image, label in tqdm(testLoader):
image = Variable(image.to(device))
label = Variable(label.to(device))
pred = model(image)
loss = criterion(pred, label)
eval_loss += loss.item()
max_value, max_index = torch.max(pred, 1)
pred_label = max_index.cpu().numpy()
true_label = label.cpu().numpy()
corrects += np.sum(pred_label == true_label)
return eval_loss / float(len(testLoader)), corrects, corrects / test_sum
首先准备标签
list=['孔雀', '布谷鸟', '梅花雀', '燕子', '赤颈鹤', '鹦鹉', '麻雀', '黄鹂']
设置device
if(torch.cuda.is_available()):
device=torch.device('cuda')
else:
device=torch.device('cpu')
加载训练好的模型
model=torch.load("./resnet50_bird.pt",map_location=device)#加载模型
图像标准化,传入训练好的模型
model.eval()
transformer = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
])
filename="./3_0.jpg"
image=Image.open(filename)
img=transformer(image)
img= img.unsqueeze(0).to(device)
pred=model(img)
处理预测的结果
如果这份博客对大家有帮助,希望各位给恒川一个免费的点赞作为鼓励,并评论收藏一下⭐,谢谢大家!!!
制作不易,如果大家有什么疑问或给恒川的意见,欢迎评论区留言。
完整工程代码链接:https://download.csdn.net/download/qq_46644680/89997623?spm=1001.2014.3001.5501
相关知识
机器学习花朵图像分类
基于深度学习和迁移学习的识花实践
度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码
基于深度学习的花卉图像分类识别模型研究
【基于PyTorch实现经典网络架构的花卉图像分类模型】
基于深度学习模型的花卉图像分类代码
如何yolov8训练使用——西红柿叶片病虫害分类数据集,1.4GB,超过2万张图像,共11大类别分类 西红柿数据集
基于深度学习特征的植物病虫害检测
使用迁移学习对花卉进行分类
深度学习实战:AlexNet实现花图像分类
网址: 【图像分类】实战篇 (2)5分钟学会用迁移学习ResNet50训练自己的图像分类模型(鸟类识别为例) https://m.huajiangbk.com/newsview1057224.html
上一篇: 深度学习 细粒度图像识别 (fi |
下一篇: 冬季月季清园了,石硫合剂使用有讲 |