首页 > 分享 > pytorch图像分类篇: 花分类数据集下载和AlexNet网络搭建训练

pytorch图像分类篇: 花分类数据集下载和AlexNet网络搭建训练

一、花分类数据集下载

data_set:该文件夹是用来存放训练数据的目录

使用步骤如下:

(1)在data_set文件夹下创建新文件夹"flower_data"

(2)点击链接下载花分类数据集

(3)解压数据集到flower_data文件夹下

(4)执行"flower_data.py"脚本自动将数据集划分成训练集train和验证集val

├── flower_data

    ├── flower_photos(解压的数据集文件夹,3670个样本)

    ├── train(生成的训练集,3306个样本)

    └── val(生成的验证集,364个样本)

二、 AlexNet网络搭建训练

1、model.py(train.py 和 present.py 均通过 from model import AlexNet 导入,因此文件名须为 model.py)

import torch.nn as nn

import torch

class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    The feature extractor stacks five convolutions (48, 128, 192, 192 and
    128 output channels) interleaved with ReLU and three max-pooling layers.
    The classifier is a three-layer MLP with dropout in front of the first
    two linear layers.

    The first linear layer expects ``128 * 6 * 6`` input features, which is
    what the feature extractor produces for 3 x 224 x 224 inputs
    (224 -> 55 -> 27 -> 13 -> 6 along each spatial axis).

    Args:
        num_classes: Size of the final linear layer (number of logits).
        init_weights: When true, re-initialise conv and linear layers with
            ``_initialize_weights`` instead of keeping PyTorch's defaults.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.features(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Re-initialise every conv and linear layer in the network.

        Conv weights use Kaiming-normal (fan_out, ReLU); linear weights use
        a normal distribution with mean 0 and std 0.01; all biases are
        set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2、train.py

import os

import sys

import json

import torch

import torch.nn as nn

from torchvision import transforms, datasets, utils

import matplotlib.pyplot as plt

import numpy as np

import torch.optim as optim

from tqdm import tqdm

from model import AlexNet

def main():
    """Train AlexNet on the flower dataset and keep the best checkpoint.

    Reads the ``train`` and ``val`` image folders under
    ``<cwd>/../../data_set/flower_data``, writes the index-to-class-name
    mapping to ``class_indices.json``, trains for a fixed number of epochs
    and saves the weights to ``./AlexNet.pth`` whenever the validation
    accuracy improves.

    Raises:
        AssertionError: If the dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        # The network's classifier needs exactly 224 x 224 inputs.
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }

    # The dataset lives two directories above the current working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Invert class_to_idx (name -> index) into index -> name and persist it
    # so the prediction script can translate predicted indices back to names.
    flower_list = train_dataset.class_to_idx
    cla_dict = {idx: name for name, idx in flower_list.items()}
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # os.cpu_count() may return None; fall back to 1 so min() cannot fail.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    # Pass the worker count that was just computed and reported; it was
    # previously hard-coded to 0, contradicting the message above.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Accuracy does not depend on sample order, so validation is not shuffled.
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # ---- training pass (dropout active) ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Extract the Python float once; reuse it for both the running
            # total and the progress-bar text.
            loss_value = loss.item()
            running_loss += loss_value

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # ---- validation pass (dropout disabled, no gradients) ----
        net.eval()
        acc = 0.0  # number of correctly classified validation images
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Only overwrite the checkpoint when validation accuracy improves.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

3、present.py

import os

import json

import torch

from PIL import Image

from torchvision import transforms

import matplotlib.pyplot as plt

from model import AlexNet

def main():
    """Classify a single image with the trained AlexNet and plot the result.

    Loads ``../tulip.jpg``, the index-to-name mapping from
    ``./class_indices.json`` and the weights from ``./AlexNet.pth``, prints
    the probability of every class and shows the image titled with the
    top prediction.

    Raises:
        AssertionError: If the image, the class-index file or the weights
            file does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Load and display the image.
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # Force three channels: the 3-channel Normalize and the first conv layer
    # cannot handle grayscale or RGBA inputs.
    img = data_transform(img.convert("RGB"))
    # Add the batch dimension: [C, H, W] -> [1, C, H, W].
    img = torch.unsqueeze(img, dim=0)

    # Index -> class-name mapping; JSON keys are strings.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    model = AlexNet(num_classes=5).to(device)

    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    # map_location lets a checkpoint saved from a CUDA model load on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                prob.item()))
    plt.show()

# Run the prediction only when this file is executed as a script.
if __name__ == '__main__':
    main()

三、结果

 

相关知识

pytorch——AlexNet——训练花分类数据集
深度学习实战:AlexNet实现花图像分类
基于pytorch搭建AlexNet神经网络用于花类识别
深度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码
Pytorch之AlexNet花朵分类
【基于PyTorch实现经典网络架构的花卉图像分类模型】
卷积神经网络训练花卉识别分类器
pytorch深度学习框架——实现病虫害图像分类
探索昆虫世界:利用hymenoptera数据集提升图像分类能力
基于深度学习和迁移学习的识花实践

网址: pytorch图像分类篇: 花分类数据集下载和AlexNet网络搭建训练 https://m.huajiangbk.com/newsview373733.html

所属分类:花卉
上一篇: 浙江推进网络学习空间应用 新添5
下一篇: 江苏省南京市(2024年