首页 > 分享 > pytorch学习之加载预训练模型

pytorch学习之加载预训练模型

最新推荐文章于 2024-11-07 09:14:34 发布

WangR0120 于 2018-12-11 16:20:20 发布

版权声明: https://blog.csdn.net/weixin_41278720/article/details/80759933

pytorch自发布以来,由于其便捷性,赢得了越来越多人的喜爱。

Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

本文,我们讲述的是models,且只谈模型的加载。models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。

模型地址:https://github.com/pytorch/vision/tree/master/torchvision/models

官方文档:https://pytorch.org/docs/master/torchvision/models.html

我将加载的方法简单总结为以下四种:

1.直接加载预训练模型

import torchvision.models as models

resnet50 = models.resnet50(pretrained= True)

1

这样就导入了resnet50的预训练模型了。

如果只需要网络结构,不需要用预训练模型的参数来初始化,那么就是:

model =torchvision.models.resnet50(pretrained=False) 1

或者把resnet复制到自己的目录下,新建个model文件夹

可以参考下面的猫狗大战入门算法入门

https://github.com/JackwithWilshere/Kaggle-Dogs_vs_Cats_PyTorch

2.修改某一层

 以resnet为例,默认的是ImageNet的1000类,比如我们要做二分类,分类猫和狗

resnet.fc = nn.Linear(2048, 2) 1

resnet 第一层卷积的卷积核是7,我们可能想改成5,那么可以通过以下方法修改:

resnet.conv1 = nn.Conv2d( 3, 64,kernel_size= 5, stride= 2, padding= 3, bias= False)

1'

3.加载部分预训练模型

对于具体的任务,很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

resnet50 = models.resnet50(pretrained= True)

pretrained_dict =resnet50.state_dict()

model_dict = model.state_dict()

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

model_dict.update(pretrained_dict)

model.load_state_dict(model_dict)

1'

4. 加载自己的模型

其实这个是保存和恢复模型,比如我们训练好的模型保存,然后加载用于测试。

方法一(推荐):

第一种方法也是官方推荐的方法,只保存和恢复模型中的参数。

使用这种方法,我们需要自己导入模型的结构信息。

(1)保存

torch.save(model.state_dict(), PATH)

torch.save(resnet50.state_dict(), 'ckp/model.pth')

1

(2)恢复

model = ModelClass(*args, **kwargs)

model.load_state_dict(torch.load(PATH))

resnet=resnet50(pretrained= True)

resnet.load_state_dict(torch.load( 'ckp/model.pth'))

1

方法二:

使用这种方法,将会保存模型的参数和结构信息。

(1)保存

torch.save (model, PATH) 1

(2)恢复

model = torch.load(PATH) 1

参考资料:

1. https://zhuanlan.zhihu.com/p/25980324

2. http://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/

相关知识

pytorch实现迁移训练
【Pytorch神经网络实战案例】24 基于迁移学习识别多种鸟类(CUB
Python基于Pytorch Transformer实现对iris鸢尾花的分类预测,分别使用CPU和GPU训练
基于深度学习和迁移学习的识花实践
基于pytorch搭建神经网络的花朵种类识别(深度学习)
【基于PyTorch实现经典网络架构的花卉图像分类模型】
pytorch 花朵的分类识别
ResNet残差网络在PyTorch中的实现——训练花卉分类器
遇到问题:读取模型 strict=False的意思 model.load
【免费】基于pytorch的深度学习花朵种类识别项目完整教程(内涵完整文件和代码)

网址: pytorch学习之加载预训练模型 https://m.huajiangbk.com/newsview1286814.html

所属分类:花卉
上一篇: 长兴岛桔园基地
下一篇: 巴彦淖尔桔梗山东花海