This article introduces Wavelet Pooling, a wavelet-transform-based pooling method that serves as an effective alternative to conventional max and average pooling. The method applies a two-level wavelet decomposition and discards the high-frequency subbands, keeping the more representative low-frequency features; this reduces information loss while strengthening regularization. We integrate WaveletPool and WaveletUnPool modules into YOLOv11, replacing its original downsampling and upsampling modules for more efficient feature extraction and recovery. Experiments show that YOLOv11-WaveletPool performs well on several classification and detection tasks, demonstrating the broad applicability of wavelet pooling in deep learning.
Article index: YOLOv11 Improvements Roundup — convolution layers, lightweight designs, attention mechanisms, loss functions, Backbone, SPPF, Neck, and detection heads (CSDN blog)
Column link: YOLOv11 Improvements column
Table of contents:
- Preface
- Introduction
- Abstract
- Paper links
- Basic principles
  - Basic principle of the wavelet transform
  - The paper's method
- Core code
- Integrating into YOLO11
  - Registration in tasks.py (Step 1: imports; Step 2: parse_model)
- Configure yolov11-WaveletPool.yaml
- Training script
- Results
**Abstract**

Convolutional Neural Networks (CNNs) continue to drive progress in 2D and 3D image classification and object recognition. Sustaining this progress, however, requires continual evaluation and improvement of the basic building blocks of neural networks. Most current network regularization methods focus on the convolution operation itself and pay comparatively little attention to the design of the pooling layer.

To address this, we propose a new pooling strategy, Wavelet Pooling, as an effective alternative to traditional neighborhood pooling methods such as max pooling and average pooling. The method decomposes features into multi-level wavelet subbands and discards the first-level subbands to perform downsampling, effectively reducing feature dimensionality. Unlike max pooling, which is prone to overfitting, wavelet pooling retains more structural information during dimensionality reduction and generalizes better. Compared with fixed-neighborhood pooling, it also achieves a more compact and efficient feature compression.

Systematic experiments on four standard image classification datasets show that the proposed wavelet pooling significantly outperforms or matches mainstream methods such as max, average, mixed, and stochastic pooling, confirming its potential as a general-purpose pooling strategy.
**Paper links**

Paper: [paper link]
Code: [code link]
**Basic principles**

First, pooling is an operation that achieves a regularization effect by discarding information. Traditional pooling methods, however, have some shortcomings:

- Max pooling: when an important feature has a lower magnitude than an unimportant one, the important feature is dropped.
- Average pooling: large- and small-magnitude features are averaged together, which tends to dilute key features.

To address these problems, the paper proposes a pooling operation based on the wavelet transform; a toy example illustrating the two failure modes is sketched below, and the key ideas of the method follow.
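The following toy example (illustrative only, not from the paper) makes both failure modes concrete, using PyTorch's built-in pooling operators on a single 2×2 window:

```py
import torch
import torch.nn.functional as F

# Toy 1x1x2x2 feature window: suppose 0.3 is the semantically important
# activation and 0.9 comes from background clutter.
x = torch.tensor([[[[0.3, 0.9],
                    [0.1, 0.2]]]])

print(F.max_pool2d(x, 2))  # tensor([[[[0.9000]]]]) -> the clutter wins, 0.3 is dropped
print(F.avg_pool2d(x, 2))  # tensor([[[[0.3750]]]]) -> the 0.3 signal is diluted
```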
**Basic principle of the wavelet transform**

The wavelet transform decomposes an input feature map into a low-frequency subband (LL) and high-frequency subbands (LH, HL, HH). In formulas:

Level-1 decomposition:

$$LL_1,\ LH_1,\ HL_1,\ HH_1 = \mathrm{DWT}(I)$$

Inverse transform:

$$I = \mathrm{IDWT}(LL_1, LH_1, HL_1, HH_1)$$

Level-2 decomposition:

$$LL_2,\ LH_2,\ HL_2,\ HH_2 = \mathrm{DWT}(LL_1)$$

Inverse transform:

$$LL_1 = \mathrm{IDWT}(LL_2, LH_2, HL_2, HH_2)$$

Each level of the wavelet transform halves the spatial size of the feature map, and the inverse transform reconstructs the original image exactly.
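For intuition, the decomposition and perfect reconstruction can be reproduced with the PyWavelets library. This sketch is illustrative (not from the paper) and assumes `pywt` is installed; it uses the Haar wavelet, which corresponds to the 2×2 filters in the code further below.

```py
import numpy as np
import pywt  # PyWavelets

# Placeholder 8x8 "image"; any 2D array works
img = np.arange(64, dtype=np.float64).reshape(8, 8)

# Level-1 decomposition: LL1 plus the detail subbands (LH1, HL1, HH1),
# each at half the spatial resolution of the input
LL1, (LH1, HL1, HH1) = pywt.dwt2(img, 'haar')
print(LL1.shape)  # (4, 4)

# The inverse transform reconstructs the input exactly (up to float error)
rec = pywt.idwt2((LL1, (LH1, HL1, HH1)), 'haar')
print(np.allclose(rec, img))  # True
```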
**The paper's method**

The method proceeds as follows: apply a two-level wavelet decomposition to the input image $I$, obtaining the first-level subbands $(LL_1, LH_1, HL_1, HH_1)$ and, from $LL_1$, the second-level subbands $(LL_2, LH_2, HL_2, HH_2)$. The first-level high-frequency subbands are then discarded, and the pooled feature map is formed from the remaining second-level content, so the output has half the spatial resolution of the input.

**Core code**

The WaveletPool and WaveletUnPool modules used in this article are implemented with fixed 2×2 Haar filters:

```py
import torch
from torch import nn as nn
import torch.nn.functional as F
import numpy as np


class WaveletPool(nn.Module):
    """Downsampling via a single-level 2D Haar transform: each input channel
    is expanded into its LL, LH, HL, HH subbands at half resolution."""
    def __init__(self):
        super(WaveletPool, self).__init__()
        # 2x2 Haar analysis filters
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                          hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        C = x.shape[1]
        # One set of 4 filters per input channel (depthwise, stride-2 conv)
        filters = torch.cat([self.weight, ] * C, dim=0)
        y = F.conv2d(x, filters, groups=C, stride=2)
        return y


class WaveletUnPool(nn.Module):
    """Upsampling via the transposed Haar filters: each group of 4 subband
    channels is merged back into one channel at double resolution."""
    def __init__(self):
        super(WaveletUnPool, self).__init__()
        ll = np.array([[0.5, 0.5], [0.5, 0.5]])
        lh = np.array([[-0.5, -0.5], [0.5, 0.5]])
        hl = np.array([[-0.5, 0.5], [-0.5, 0.5]])
        hh = np.array([[0.5, -0.5], [-0.5, 0.5]])
        filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1],
                          hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]],
                         axis=0)
        self.weight = nn.Parameter(
            torch.tensor(filts).to(torch.get_default_dtype()),
            requires_grad=False)

    def forward(self, x):
        C = torch.floor_divide(x.shape[1], 4)
        filters = torch.cat([self.weight, ] * C, dim=0)
        y = F.conv_transpose2d(x, filters, groups=C, stride=2)
        return y
```
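As a quick sanity check (not part of the original article), the two modules can be exercised on a dummy tensor: WaveletPool quadruples the channel count and halves the spatial size, and WaveletUnPool reverses that bookkeeping.

```py
import torch

# Assumes the WaveletPool / WaveletUnPool classes defined above are in scope
x = torch.randn(1, 16, 64, 64)
pool, unpool = WaveletPool(), WaveletUnPool()

y = pool(x)
print(y.shape)  # torch.Size([1, 64, 32, 32]) -> 4x channels, half resolution

z = unpool(y)
print(z.shape)  # torch.Size([1, 16, 64, 64]) -> back to the input shape
```

These factors of 4 are exactly what the parse_model registration below accounts for.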
**Integrating into YOLO11**

Under the project root, in the ultralytics/nn/ directory, create a new otherModules directory, then create a file named WaveletPool.py and copy the core code above into it.
**Registration in tasks.py**

Make the following changes in ultralytics/nn/tasks.py.

Step 1: import the new modules:
```py
from ultralytics.nn.otherModules.WaveletPool import WaveletPool, WaveletUnPool
```
Step 2: modify `def parse_model(d, ch, verbose=True):` by adding the two branches below to its module-handling if/elif chain. Only these branches need to be added; modules not listed here require no extra handling.
```py
        elif m in {WaveletPool}:
            c2 = ch[f] * 4   # WaveletPool outputs 4x the input channels
        elif m in {WaveletUnPool}:
            c2 = ch[f] // 4  # WaveletUnPool merges every 4 channels into 1
```
**Configure yolov11-WaveletPool.yaml**

Create the following model config at ultralytics/cfg/models/11/yolov11-WaveletPool.yaml:
```yaml
# Ultralytics YOLO, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, WaveletPool, []] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, WaveletPool, []] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, WaveletPool, []] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, WaveletUnPool, []]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, WaveletUnPool, []]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, WaveletPool, []]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, WaveletPool, []]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
```
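Before training, it can be worth confirming that the modified config parses and the new modules are registered. A minimal hedged check (the config path is the one used in this article; adjust it to your setup):

```py
from ultralytics import YOLO

# Building the model from the YAML exercises parse_model, and therefore the
# WaveletPool / WaveletUnPool branches added above; info() prints a summary.
model = YOLO('ultralytics/cfg/models/11/yolov11-WaveletPool.yaml')
model.info()
```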
**Training script**

```py
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # Change to your own config file path
    model = YOLO('/root/ultralytics-main/ultralytics/cfg/models/11/yolov11-WaveletPool.yaml')
    # Change to your own dataset path
    model.train(data='/root/ultralytics-main/ultralytics/cfg/datasets/coco8.yaml',
                cache=False,
                imgsz=640,
                epochs=10,
                single_cls=False,  # whether this is single-class detection
                batch=8,
                close_mosaic=10,
                workers=0,
                optimizer='SGD',
                # amp=True,
                project='runs/train',
                name='WaveletPool',
                )
```
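After training, the best checkpoint can be evaluated on the validation split. A minimal sketch, assuming the default save location produced by the script above (project='runs/train', name='WaveletPool') and the same dataset YAML:

```py
from ultralytics import YOLO

if __name__ == '__main__':
    # Path follows from project='runs/train', name='WaveletPool' in the training script
    model = YOLO('runs/train/WaveletPool/weights/best.pt')
    metrics = model.val(data='/root/ultralytics-main/ultralytics/cfg/datasets/coco8.yaml',
                        imgsz=640, batch=8)
    print(metrics.box.map)    # mAP50-95
    print(metrics.box.map50)  # mAP50
```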