利用Kaggle上的一个公开数据集,下载连接如下:
https://www.kaggle.com/datasets/alxmamaev/flowers-recognition
其是一些花的照片,共有5类,四千多张照片。
整个数据集并不大,因此可以将其先读入到内存(显存中),而不再需要每次要用到的时候再从硬盘中读取,能够有效地提升运行速度。
而图片的数量并不多,因此还需要用到图片增广技术。
Kaggle上的数据已经按照文件夹将图片分好类了,因此读取图片的时候,需要按照文件夹来归类。
class Flower_Dataset(Dataset): def __init__(self, path , is_train, augs): data_root = pathlib.Path(path) all_image_paths = list(data_root.glob('*/*')) self.all_image_paths = [str(path) for path in all_image_paths] label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir()) label_to_index = dict((label, index) for index, label in enumerate(label_names)) self.all_image = [cv.imread(path) for path in self.all_image_paths] self.all_image_labels = [label_to_index[path.parent.name] for path in all_image_paths] 123456789
考虑花的图片,水平变换之后仍然是一朵花,因此可以使用此种增广方式。
此为,亮度、对比度等调整均可使用。
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5) augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug]) 12
每次从数据集中抽取一个批量的大小。
一般情况下使用打乱顺序的方式。
train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers= 4) test_iter = DataLoader(test_set, batch_size=batch_size, num_workers= 4) 12
采用经典的resnet模型,由于数据集大小有限,不宜采用过于复杂的网络,故在此选用了resnet18,其共有68层,不算太深,具体结构如下:
---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 64, 112, 112] 9,408 BatchNorm2d-2 [-1, 64, 112, 112] 128 ReLU-3 [-1, 64, 112, 112] 0 MaxPool2d-4 [-1, 64, 56, 56] 0 Conv2d-5 [-1, 64, 56, 56] 36,864 BatchNorm2d-6 [-112345678
相关知识
使用PyTorch实现对花朵的分类
pytorch实现简单卷积神经网络(CNN)网络完成手写数字识别
(四)pytorch图像识别实战之用resnet18实现花朵分类(代码+详细注解)
Pytorch框架实战——102类花卉分类
深度学习实战:AlexNet实现花图像分类
ResNet残差网络在PyTorch中的实现——训练花卉分类器
pytorch深度学习框架——实现病虫害图像分类
植物病害检测系统:利用深度学习守护农田的科技先锋
CNN卷积神经网络:花卉分类
基于pytorch搭建AlexNet神经网络用于花类识别
网址: 【pytorch】CNN实战 https://m.huajiangbk.com/newsview516075.html
上一篇: 影响花艺作品的对比与调和的核心因 |
下一篇: 邻里课堂 |