kaggle挑战赛题目,构造一个分类模型,准确的识别出图像中木薯叶子感染的具体疾病。
详情可以参考链接:Cassva Leaf Disease Classification
任务就是训练一个分类模型,能够准确的识别出图中木薯叶感染了哪种疾病。本次竞赛数据集中定义了5种类别:
{"0": "Cassava Bacterial Blight (CBB)", "1": "Cassava Brown Streak Disease (CBSD)", "2": "Cassava Green Mottle (CGM)", "3": "Cassava Mosaic Disease (CMD)", "4": "Healthy"}
解题思路针对图像分类问题,有很多成熟的框架可以使用。比如:VGG、ResNet、EfficientNet或者基于机器学习的分类方法,SVM、Bagging、Boosting等。
解决算法模型相关问题,大致可以分为三个步骤:
1、数据预处理
2、确定模型架构
3、模型验证与调优
数据是问题的抽象变现,算法模型是针对数据表现出的问题提炼出的解决方法。毫不夸张的说,任何与算法模型相关的问题实际上都是分析数据、提炼数据的问题。
数据的预处理包括但不限于数据的收集、标注、增强、去噪声等。由于kaggle比赛会直接提供已标注的数据集,所以可以跳过数据收集和标注过程,专注于分析数据特征,根据预分析得到的数据特征,考虑数据增强方法和去噪方法。
经过博主分析统计,以及参考其他人分享的经验,数据中存在标注噪声,所以Label smoothing方法应该被考虑,同时数据中存在类别不均衡的现象,5中类别的样本数量不相等,所以类别不均衡现象也应该被考虑,focal-loss是解决类别不均衡问题的一个可选方法。
一、第一次实验
利用ResNet50,训练30个epoch,学习率设置为0.05,optimizer=Adam,提交acc=0.699。
ResNet网络结构。
class ResNet(nn.Module):
def __init__(self,blocks,num_classes=5,expansion=4):
super(ResNet,self).__init__()
self.expansion = expansion
self.conv1 = Conv1(in_planes=3,places=64)
self.layer1 = self.make_layer(in_places=64,places=64,block=blocks[0],stride=1)
self.layer2 = self.make_layer(in_places=256,places=128,block=blocks[1],stride=2)
self.layer3 = self.make_layer(in_places=512,places=256,block=blocks[2],stride=2)
self.layer4 = self.make_layer(in_places=1024,places=512,block=blocks[3],stride=2)
self.avgpool = nn.AvgPool2d(7,stride=1)
self.fc = nn.Linear(2048,num_classes)
for m in self.modules():
if isinstance(m,nn.Conv2d):
nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
elif isinstance(m,nn.BatchNorm2d):
nn.init.constant_(m.weight,1)
nn.init.constant_(m.bias,0)
def make_layer(self,in_places,places,block,stride):
layers = []
layers.append(Bottleneck(in_places,places,stride,downsampling=True))
for i in range(1,block):
layers.append(Bottleneck(places*self.expansion,places))
return nn.Sequential(*layers)
def forward(self,x):
x = self.conv1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0),-1)
x = self.fc(x)
return x
ResNet50、ResNet101、ResNet152区别在于layer1,layer2,layer3,layer4中block数目。所以,利用上面的结构可以同时得到实例化的ResNet50、ResNet101、ResNet152。
def ResNet50():
return ResNet([3,4,6,3])
def ResNet101():
return ResNet([3,4,23,3])
def ResNet152():
return ResNet([3,8,36,3])
'submit之后的acc不到70%,很明显是模型没有收敛到全局最优,可能的结果是学习率设置不合理,过小还未收敛,过大跳过全局最优点。
所以设置动态学习率,同时在此基础上继续resume learning。
二、第二次实验
利用第一次实验的结果进行Resume Learning,optimizer=SGD,引入lr_scheduler,lr范围设置为[0.005,0.05],训练30个epoch,提交acc=0.810.
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.005, max_lr=0.05)
30个epoch之后,loss不在下降,在0.5附近一直震荡,推测可能学习率设置过大,可以设置为0.01进行实验。
三、第三次实验
前两次实验没有使用太多图像增强技术,本次实验引入大量图像增强技术,研究图像增强对模型性能的影响。
推荐使用albumentations,该方法集成了大量图像增强算法,可以直接嵌入到数据预处理阶段。
本次实验中使用的图像增强方法:
def get_inference_transforms():
return Compose([
RandomResizedCrop(224,224),
Transpose(p=0.5),
HorizontalFlip(p=0.5),
VerticalFlip(p=0.5),
HueSaturationValue(hue_shift_limit=0.2,sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
ToTensorV2(p=1.0),
],p=1.)
'在做推理时,需要注意,测试图像必须要经过相同的图像增强方法处理,否则提交的acc会很低。
训练60个epoch,依然使用CyclicLR进行动态调整lr,参数设置与实验2相同,最后提交的acc=0.835
通过实验三可以得出适当的进行图像增强可以提高模型的性能。
四、第四次实验
引入k-flod validation方法训练模型。
k折交叉验证在传统机器学习模型中有着广泛的应用,在缺乏大量训练样本时,该算法可以提高在有限数据集上训练得到模型的泛化能力。
推荐使用sklearn model_selection 中的StratifiedKFold方法,k设置为5。每一折的epochs均设置为60,初始学习率设置为1e-2,学习率调整策略设置为MultiStepLR。沿用实验三中的图像增强处理,提交acc=0.801。
对比实验三,实验四的模型复杂程度提高,但是提交测试准确率下降,推测是发生了过拟合,或者是学习率设置不合理。
五、第五次实验
使用snapmix增强方法,使用pretrained参数,图像尺寸设置为512,5folds训练策略。
在第一个fold训练完之后,取得目前位置最好的提交结果,acc=0.885。
由于图像 尺寸变大,每个epoch的训练时间比前四次实验都要长。
使用第二个fold训练acc是0.910的参数提交acc=0.879
使用第三个flod训练acc是0.950的参数提交acc=0.874
初步分析结果是发生了过拟合。
最后提交LeaderBoard得分0.886
模型架构
Input -> ResNet50 + ResNet101 -> classes
训练策略
1、snapmix数据融合方法
2、少量数据增强方法:归一化,旋转,反转,随机裁剪等
3、交叉验证
4、Adam + multistepLR
kaggle 最后提交代码链接 kaggle提交
TTA(测试时增强)
这里会为原始图像造出多个不同版本,包括不同区域裁剪和更改缩放程度等,并将它们输入到模型中;然后对多个版本进行计算得到平均输出,作为图像的最终输出分数。
相关知识
CropNet: Cassava Disease Detection
农作物病虫害识别进展概述
最全 农作物病害数据集汇总(不定期更新)
作物病虫害识别数据集资源合集
应用卷积神经网络识别花卉及其病症
Plant disease identification method based on lightweight CNN and mobile application
面向大规模多类别的病虫害识别模型
基于神经结构搜索的多种植物叶片病害识别
基于深度学习的农作物病害图像识别技术进展
Research Progress in Detection and Identification of Crop Diseases and Insect Pests Based on Deep Learning
网址: Kaggle Cassava Leaf Disease Classification 木薯叶疾病分类竞赛 https://m.huajiangbk.com/newsview484006.html
上一篇: 梨树病虫害识别与防治汇编.ppt |
下一篇: 某自花受粉的二倍体植物(2n=1 |