首页 > 分享 > 03常用pytorch剪枝工具

03常用pytorch剪枝工具

常用剪枝工具

pytorch官方案例

import torch.nn.utils.prune as prune

import torch from torch import nn import torch.nn.utils.prune as prune import torch.nn.functional as F print(torch.__version__) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6, 16, 3) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, int(x.nelement() / x.shape[0])) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = LeNet().to(device=device) module = model.conv1 prune.random_structurd(module, name="weight", amount=0.3, dim=1) #对同一层进行连续不同的剪枝 prune.l1_unstructured(module, name="weight", amount=3) prune.l1_unstructured(module, name="bias", amount=3) prune.ln_structured(module, name="bias", amount=0.5, n=3, dim=0)

序列化剪枝后的模型

在PyTorch中,named_buffers()是一个模型的方法,它返回一个迭代器,这个迭代器包含了模型中所有持久化的缓冲区。在每次迭代中,它返回一个包含缓冲区名(name)和缓冲区的张量(tensor)的元组。

在神经网络中,有些数据虽然不是模型参数(也就是不会在反向传播中被更新),但是这些数据在前向传播过程中是需要的,这些数据就被称为缓冲区(buffer)。缓冲区通常用于存储不参与梯度计算,但需要在训练过程中持久化的数据。例如,批归一化(Batch Normalization)层中的运行平均值和运行方差就是存储在缓冲区中的。

对于剪枝操作来说,剪枝的掩码通常会被保存为一个缓冲区。这个掩码的作用是在前向传播过程中把被剪枝的权重(也就是被设为0的权重)从计算中排除出去。

所以,named_buffers()函数就是用来获取模型中所有缓冲区的名称和对应的数据。这在进行剪枝操作时,可以用来检查剪枝的掩码是否已经被正确地添加到模型中。

#state_dict()是一个PyTorch模型的方法,它返回一个字典,其中包含了模型的所有参数,包括权重和偏置。字典的键是参数的名称,值是参数的值。这个字典可以用于保存和加载模型的参数。 #keys()是Python字典的一个方法,它返回字典的所有键的列表。 #所以,model.state_dict().keys()返回的是一个包含模型中所有参数名称的列表。weight和bias print(model.state_dict().keys()) new_model = LeNet() #这行代码开始遍历模型中的所有模块(或层)。named_modules()函数返回一个迭代器,每次迭代返回一个包含模块名(name)和模块实例(module)的元组。 for name, module in new_model.named_modules(): # prune 20% of connections in all 2D-conv layers if isinstance(module, torch.nn.Conv2d): prune.l1_unstructured(module, name='weight', amount=0.2) # prune 40% of connections in all linear layers elif isinstance(module, torch.nn.Linear): prune.l1_unstructured(module, name='weight', amount=0.4) print(dict(new_model.named_buffers()).keys()) # to verify that all masks exist

global pruning

model = LeNet() #第一个元素是model,第二个元素是这个model里哪一些参数要被剪掉 parameters_to_prune = ( (model.conv1, 'weight'), (model.conv2, 'weight'), (model.fc1, 'weight'), (model.fc2, 'weight'), (model.fc3, 'weight'), ) #进行全局无结构剪枝 prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2, )

"Sparsity"(稀疏性)是一个数学概念,用于描述一个矩阵中零元素的比例。在深度学习中,稀疏性通常用来描述模型权重矩阵中零值的比例。

print( "Sparsity in conv1.weight: {:.2f}%".format( 100. * float(torch.sum(model.conv1.weight == 0)) / float(model.conv1.weight.nelement()) ) ) print( "Sparsity in conv2.weight: {:.2f}%".format( 100. * float(torch.sum(model.conv2.weight == 0)) / float(model.conv2.weight.nelement()) ) ) print( "Sparsity in fc1.weight: {:.2f}%".format( 100. * float(torch.sum(model.fc1.weight == 0)) / float(model.fc1.weight.nelement()) ) ) print( "Sparsity in fc2.weight: {:.2f}%".format( 100. * float(torch.sum(model.fc2.weight == 0)) / float(model.fc2.weight.nelement()) ) ) print( "Sparsity in fc3.weight: {:.2f}%".format( 100. * float(torch.sum(model.fc3.weight == 0)) / float(model.fc3.weight.nelement()) ) ) print( "Global sparsity: {:.2f}%".format( 100. * float( torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0) ) / float( model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement() ) ) )

自定义pruning functions

下面是每隔一个就进行一次非结构化剪枝

自定义剪枝pytorch官方教程: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#:~:text=Global sparsity%3A 20.00%25-,Extending torch.nn.utils.prune with custom pruning,-functions

pytorch源码参考: https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/prune.py#:~:text=%40abstractmethod,method recipe.

#该类是prune.BasePruningMethod的子类 class ImplEveryOtherPruningMethod(prune.BasePruningMethod): #定义剪枝类型 PRUNING_TYPE = 'unstructured' #重写了基类中的抽象方法compute_mask。该方法接收两个参数,一个是待剪枝的张量t,另一个是默认的掩码default_mask。 def compute_mask(self, t, default_mask): #创建一个default_mask的副本,这是为了避免改变原始的default_mask。 mask = default_mask.clone() #这个操作首先将掩码的形状改为一维mask.view(-1),然后选择索引为偶数的所有元素[::2],将它们设置为0。这样就达到了每隔一个元素剪枝的效果。 mask.view(-1)[::2] = 0 return mask def Ieveryother_unstructured_prune(module, name): #生成一个想要的mask,并且apply到module的元素上 ImplEveryOtherPruningMethod.apply(module, name) return module model = LeNet() Ieveryother_unstructured_prune(model.fc3, name='bias') print(model.fc3.bias_mask)

相关知识

pytorch 花朵的分类识别
插花艺术常用的工具介绍
艺术插花常用的工具
PyTorch环境配置及安装
pytorch深度学习框架——实现病虫害图像分类
一种玫瑰花打刺剪枝工具的制作方法
【大虾送书第二期】《Python机器学习:基于PyTorch和Scikit
创建虚拟环境并,创建pytorch 1.3.1
西瓜如何剪枝
翻译: 3.3. 线性回归的简明实现 pytorch

网址: 03常用pytorch剪枝工具 https://m.huajiangbk.com/newsview1401057.html

所属分类:花卉
上一篇: 驱蚊草盆栽的养殖方法(室内盆栽驱
下一篇: worth 沃施 园艺花艺剪刀家