首页 > 分享 > pytorch——AlexNet——训练花分类数据集

pytorch——AlexNet——训练花分类数据集

宝藏博主:霹雳吧啦Wz_太阳花的小绿豆_CSDN博客-深度学习,Tensorflow,软件安装领域博主

目录

数据集下载

 训练集与测试集划分

 “split_data.py”

 Alexnet讲解:

名称解读

1)过拟合: 

2) Dropout:

3)gpu

 1. model.py

2. train.py

先对训练集的预处理

1)transforms.ToTensor(), 

2) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

3)transforms.Compose()类详解:串联多个transform操作

导入、加载 训练集 

导入、加载 验证集

存储 索引:标签 的字典

训练过程

3. predict.py

数据集下载

http://download.tensorflow.org/example_images/flower_photos.tgz
包含 5 中类型(雏菊,蒲公英,玫瑰,向日葵,郁金香)的花,每种类型有600~900张图像

 训练集与测试集划分

因为这次数据集不在dataset中,因此需要自己划分

参考:deep-learning-for-image-processing/README.md at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub

1.先将数据集压缩包解压到data_set文件夹中的flower_data中

2.在data_set目录下执行 shift + 右键 打开 PowerShell ,

3.执行 “split_data.py” 分类脚本自动将数据集划分成 训练集train 和 验证集val

 “split_data.py”

import os

from shutil import copy

import random

def mkfile(file):

if not os.path.exists(file):

os.makedirs(file)

file_path = 'flower_data/flower_photos'

flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

mkfile('flower_data/train')

for cla in flower_class:

mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')

for cla in flower_class:

mkfile('flower_data/val/'+cla)

split_rate = 0.1

for cla in flower_class:

cla_path = file_path + '/' + cla + '/'

images = os.listdir(cla_path)

num = len(images)

eval_index = random.sample(images, k=int(num*split_rate))

for index, image in enumerate(images):

if image in eval_index:

image_path = cla_path + image

new_path = 'flower_data/val/' + cla

copy(image_path, new_path)

else:

image_path = cla_path + image

new_path = 'flower_data/train/' + cla

copy(image_path, new_path)

print("r[{}] processing [{}/{}]".format(cla, index+1, num), end="")

print()

print("processing done!")

 Alexnet讲解:

 

名称解读

1)过拟合: 

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价

我的理解是因为太贴合训练集结果,导致我们的程序过度解读特征,我们的训练后的模型不能很好的预测其他的数据,而几乎完美贴合测试集

2) Dropout:

Dropout说的简单一点就是:我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作(意思就是随机失活下一层的神经元)

3)gpu

gpu加速就不用说了,打游戏深有体会

具体图 

 1. model.py

import torch.nn as nn

import torch

class AlexNet(nn.Module):

def __init__(self, num_classes=1000, init_weights=False):

super(AlexNet, self).__init__()

self.features = nn.Sequential(

nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=3, stride=2),

nn.Conv2d(48, 128, kernel_size=5, padding=2),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=3, stride=2),

nn.Conv2d(128, 192, kernel_size=3, padding=1),

nn.ReLU(inplace=True),

nn.Conv2d(192, 192, kernel_size=3, padding=1),

nn.ReLU(inplace=True),

nn.Conv2d(192, 128, kernel_size=3, padding=1),

nn.ReLU(inplace=True),

nn.MaxPool2d(kernel_size=3, stride=2),

)

self.classifier = nn.Sequential(

nn.Dropout(p=0.5),

nn.Linear(128 * 6 * 6, 2048),

nn.ReLU(inplace=True),

nn.Dropout(p=0.5),

nn.Linear(2048, 2048),

nn.ReLU(inplace=True),

nn.Linear(2048, num_classes),

)

if init_weights:

self._initialize_weights()

def forward(self, x):

x = self.features(x)

x = torch.flatten(x, start_dim=1)

x = self.classifier(x)

return x

def _initialize_weights(self):

for m in self.modules():

if isinstance(m, nn.Conv2d):

nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

if m.bias is not None:

nn.init.constant_(m.bias, 0)

elif isinstance(m, nn.Linear):

nn.init.normal_(m.weight, 0, 0.01)

nn.init.constant_(m.bias, 0)

2. train.py

先对训练集的预处理

data_transform = {

"train": transforms.Compose([transforms.RandomResizedCrop(224),

transforms.RandomHorizontalFlip(p=0.5),

transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

"val": transforms.Compose([transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

1)transforms.ToTensor(), 

ToTesnor会数据归一化到均值为0,方差为1(是将数据除以255)

2) transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

本来就将数据缩小了,那为什么ToTensor后面加Normalize?

我找到很好的博文分享给大家

Normalize()是对数据按通道进行标准化,即减去均值,再除以方差

数据如果分布在(0,1)之间,可能实际的bias,就是神经网络的输入b会比较大,而模型初始化时b=0的,这样会导致神经网络收敛比较慢,经过Normalize后,可以加快模型的收敛速度
因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。
————————————————
版权声明:本文为CSDN博主「小研一枚」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_35027690/article/details/103742697

3)transforms.Compose()类详解:串联多个transform操作

 参考transforms.Compose()类详解:串联多个transform操作_migue_math-CSDN博客_transforms.compose

导入、加载 训练集 

但是这次的 花分类数据集 并不在 pytorch 的 torchvision.datasets. 中,因此需要用到datasets.ImageFolder()    来导入。

ImageFolder()返回的对象是一个包含数据集所有图像及对应标签构成的二维元组容器,支持索引和迭代,可作为torch.utils.data.DataLoader的输入

参考Pytorch 加载图像数据(ImageFolder和Dataloader)_陶将的博客-CSDN博客

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))

image_path = data_root + "/data_set/flower_data/"

train_dataset = datasets.ImageFolder(root=image_path + "/train",

transform=data_transform["train"])

train_num = len(train_dataset)

train_loader = torch.utils.data.DataLoader(train_dataset,

batch_size=32,

shuffle=True,

num_workers=0)

导入、加载 验证集

validate_dataset = datasets.ImageFolder(root=image_path + "/val",

transform=data_transform["val"])

val_num = len(validate_dataset)

validate_loader = torch.utils.data.DataLoader(validate_dataset,

batch_size=32,

shuffle=True,

num_workers=0)

存储 索引:标签 的字典

这和爬虫爬取数据操作很像

flower_list = train_dataset.class_to_idx

cla_dict = dict((val, key) for key, val in flower_list.items())

json_str = json.dumps(cla_dict, indent=4)

with open('class_indices.json', 'w') as json_file:

json_file.write(json_str)

训练过程

net.train():训练过程中开启 Dropoutnet.eval(): 验证过程关闭 Dropout

net = AlexNet(num_classes=5, init_weights=True)

net.to(device)

loss_function = nn.CrossEntropyLoss()

optimizer = optim.Adam(net.parameters(), lr=0.0002)

save_path = './AlexNet.pth'

best_acc = 0.0

for epoch in range(10):

net.train()

running_loss = 0.0

time_start = time.perf_counter()

for step, data in enumerate(train_loader, start=0):

images, labels = data

optimizer.zero_grad()

outputs = net(images.to(device))

loss = loss_function(outputs, labels.to(device))

loss.backward()

optimizer.step()

running_loss += loss.item()

rate = (step + 1) / len(train_loader)

a = "*" * int(rate * 50)

b = "." * int((1 - rate) * 50)

print("rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")

print()

print('%f s' % (time.perf_counter()-time_start))

net.eval()

acc = 0.0

with torch.no_grad():

for val_data in validate_loader:

val_images, val_labels = val_data

outputs = net(val_images.to(device))

predict_y = torch.max(outputs, dim=1)[1]

acc += (predict_y == val_labels.to(device)).sum().item()

val_accurate = acc / val_num

if val_accurate > best_acc:

best_acc = val_accurate

torch.save(net.state_dict(), save_path)

print('[epoch %d] train_loss: %.3f test_accuracy: %.3f n' %

(epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

生成pth的模型文件 

3. predict.py

import torch

from model import AlexNet

from PIL import Image

from torchvision import transforms

import matplotlib.pyplot as plt

import json

data_transform = transforms.Compose(

[transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

img = Image.open("./01.jpeg")

plt.imshow(img)

img = data_transform(img)

img = torch.unsqueeze(img, dim=0)

try:

json_file = open('./class_indices.json', 'r')

class_indict = json.load(json_file)

except Exception as e:

print(e)

exit(-1)

model = AlexNet(num_classes=5)

model_weight_path = "./AlexNet.pth"

model.load_state_dict(torch.load(model_weight_path))

model.eval()

with torch.no_grad():

output = torch.squeeze(model(img))

predict = torch.softmax(output, dim=0)

predict_cla = torch.argmax(predict).numpy()

print(class_indict[str(predict_cla)], predict[predict_cla].item())

plt.show()

相关知识

Pytorch之AlexNet花朵分类
卷积神经网络训练花卉识别分类器
Alex
pytorch深度学习框架——实现病虫害图像分类
深度学习实战(1):花的分类任务|附数据集与源码
基于深度学习和迁移学习的识花实践
102类花卉分类数据集(已划分,有训练集、测试集、验证集标签)
基于深度学习的植物病害检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
基于深度学习的玉米病虫害检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
基于深度学习的花卉识别

网址: pytorch——AlexNet——训练花分类数据集 https://m.huajiangbk.com/newsview150519.html

所属分类:花卉
上一篇: 花生遇到根腐病怎么办?花生根腐病
下一篇: 春天适合养仙人球,这几款很不错,