https://blog.csdn.net/m0_47867638/article/details/143220101?sharetype=blogdetail&sharerId=143220101&sharerefer=PC&sharesource=m0_47867638&spm=1011.2480.3001.8118
论文介绍了一种新的WTConv模块,该模块通过利用小波变换有效地增加了卷积的感受野,并作为深度卷积的即插即用替代品在多个计算机视觉任务中表现出色。使用WTConv替换YoloV8的Conv模块有望带来类似的改进效果。
https://arxiv.org/pdf/2407.05848 近年来,人们尝试增大卷积神经网络(CNNs)的核大小,以模仿视觉转换器(ViTs)自注意力模块的全局感受野。然而,这种方法很快便达到了上限,并且在实现全局感受野之前就饱和了。在本文中,我们证明了通过利用小波变换(WT),实际上可以在不遭受过度参数化的情况下获得非常大的感受野,例如,对于感受野,所提出方法中可训练参数的数量仅随对数增长。所提出的层,命名为WTConv,可以作为现有架构中的即插即用替代品,产生有效的多频响应,并且随着感受野大小的增加而优雅地扩展。我们证明了WTConv层在ConvNeXt和MobileNetV2架构中用于图像分类以及作为下游任务的主干网络时的有效性,并展示了它带来了额外的特性,如对图像损坏的鲁棒性和对形状而非纹理的响应增强。我们的代码可在https://github.com/BGU-CS-VIL/WTConv上获得。
关键词:小波变换 感受野 多频
在过去十年中,卷积神经网络(CNNs)在很大程度上主导了计算机视觉的许多领域。尽管如此,随着最近视觉转换器(ViTs)[12]的出现,这是自然语言处理中使用的转换器架构[59]的一种改编,CNNs面临着激烈的竞争。具体而言,ViTs现在被认为相对于CNNs的优势主要归因于其多头自注意力层。该层促进了特征的全局混合,而卷积在结构上仅限于特征的局部混合。因此,最近的一些工作试图缩小CNNs和ViTs之间的性能差距。Liu等人[38]重构了ResNet架构及其训练流程,以跟上Swin Transformer[37]的步伐。其中一项改进是增加了卷积的核大小。然而,从经验上看,这种方法在核大小为时便饱和了,这意味着进一步增大核大小无济于事,甚至在某些时候开始恶化性能。虽然单纯地将大小增加到超过并无益处,但Ding等人[11]已经表明,如果内核构建得更好,则可以从中受益,即使内核更大。即便如此,最终内核会变得过度参数化,并且性能在达到全局感受野之前就饱和了。
[11]中分析的一个有趣特性是,使用更大的内核会使CNNs更具形状偏见,这意味着它们捕获图像中低频信息的能力得到了提高。这一发现有些令人惊讶,因为卷积层通常倾向于对输入中的高频信息作出响应[17,19,56,61]。这与注意力头不同,已知注意力头更适应于低频信息,如其他研究所示[44, 45, 56]。
上述讨论引出了一个自然的问题:我们能否利用信号处理工具来有效地增加卷积的感受野,同时避免过度参数化?换句话说,我们能否拥有非常大的滤波器(例如,具有全局感受野)并仍然提高性能?本文对这个问题给出了肯定的回答。我们提出的方法利用了小波变换(WT)[9],这是时频分析中的一个成熟工具,使卷积的感受野能够很好地扩展,并且通过级联,还能引导卷积神经网络(CNNs)更好地响应低频信息。部分地,我们基于WT提出解决方案的动机是(与例如傅里叶变换不同),它保留了一些空间分辨率。这使得在小波域中的空间操作(例如卷积)更有意义。
更具体地说,我们提出了WTConv层,该层使用级联WT分解并执行一组小核卷积,每个卷积都在越来越大的感受野中专注于输入的不同频率带。这个过程使我们能够更多地关注输入中的低频信息,同时只增加少量可训练参数。事实上,对于感受野,我们的可训练参数数量仅随对数增长。这一事实与一些最近的方法(其相应增长是二次的)相比,使我们能够获得具有前所未有的有效感受野(ERF)[40]大小的有效CNN(见图1)。
我们将WTConv设计为深度卷积的即插即用替代品,可以在任何给定的CNN架构中直接使用,而无需额外修改。我们通过将WTConv集成到ConvNeXt[38]中进行图像分类,验证了其有效性,证明了它在基本视觉任务中的实用性。进一步利用ConvNeXt作为主干网络,我们将评估扩展到更复杂的应用:在UperNet[65]中使用它进行语义分割,在Cascade Mask R-CNN[2]中使用它进行目标检测。此外,我们还分析了WTConv为CNN带来的额外好处。
综上所述,我们的主要贡献是:
小波变换(WT)[9]是信号处理和分析的有力工具,自20世纪80年代以来已被广泛使用。在经典设置中取得成功之后,最近WT也被纳入神经网络架构中,用于各种任务。Wang等人[63]从心电图(ECG)信号的时频分量中提取特征。Huang等人[32]和Guo等人[22]预测输入图像的小波高频系数,以重建更高分辨率的输出。Duan等人[13]和Williams与Li[64]将WT用作CNN中的池化算子。Gal等人[16]、Guth等人[23]和Phung等人[46]在生成模型中使用小波来增强生成图像的视觉质量,并提高计算性能。Finder等人[14]利用小波压缩特征图,以提高CNN的效率。Saragadam等人[51]将小波用作隐式神经表示的激活函数。
与我们的工作更相关的是,Liu等人[35]和Alaba等人[1]在修改后的U-Net架构[49]中使用WT进行下采样,并使用逆WT进行上采样。在另一项与我们的工作相关的工作中,Fujieda等人[15]提出了一种DenseNet类型的架构,该架构使用小波将输入中的较低频率重新引入到后续层中。虽然与小波无关,但Chen等人[3]提出通过对图像进行初步的高、低分辨率分离,并在网络中沿两个分辨率之间进行信息交换,来对多分辨率输入执行卷积。这些工作证明了将输入的低频分量与高频分量分开进行卷积的好处,以获得更具信息量的特征图。这一特性也激发了我们的工作。然而,来自[1,15,35]的方法是高度定制的架构,不能无缝地用于其他CNN架构,而[3]则侧重于计算效率。相比之下,我们提出了一个更轻量、更易用、线性的层,它可以作为深度卷积的即插即用替代品,并导致感受野的改善。重要的是,我们的方法可以适应任何使用深度卷积的网络,因此不限于单一任务。
在卷积配置方面,VGG[52]通过使用卷积,牺牲了单层感受野的大小来增加网络的深度(从不到10层到大约20层),为现代CNN树立了标准。从那时起,随着计算量的增加和架构的改进,CNN变得更深,但核大小参数在很大程度上仍未被探索。
传统卷积的一个重大变化是引入了可分离卷积【58,62】。可分离卷积由Xception【5】和MobileNet【30】推广,并被大多数现代架构【38,50】采用。在这种方法中,空间卷积是按通道(即深度方向)进行的,而跨通道操作是使用核(即逐点方式)进行的。这种卷积的分离也在核大小和通道维度之间(就参数数量和操作数量而言)产生了一定程度的分离。现在,每个具有核大小和通道的空间卷积仅有个参数(而不是),这使得它虽然仍然是二次方的,但能更好地随扩展。
同时,具有非局部自注意力层的Transformer在视觉任务【12,37】中的引入通常比局部混合卷积产生了更好的结果。这与上述最近对可分离卷积的使用一起,重新激发了人们探索更大卷积核用于卷积神经网络(CNN)的兴趣。特别是,Liu等人【38】重新审视了流行的ResNet架构【26】,包括对不同卷积核大小进行了实证比较,得出结论:性能在卷积核大小为时达到饱和。Trockman和Kolter【55】尝试仅使用卷积来模仿ViT架构,并通过使用卷积替换注意力(或“混合器”)组件展示了令人印象深刻的结果。Ding等人【11】提出,仅仅增加卷积核的大小会损害卷积的局部性属性。因此,他们建议并行使用一个小卷积核和一个大卷积核,然后将其输出相加。使用这种技术,他们成功训练了具有高达卷积核的CNN。Liu等人【36】通过将其分解为一组并行的和卷积核,成功地将卷积核大小增加到。此外,他们还引入了稀疏性,同时扩展了网络的宽度。然而,这种使用更多通道(具有稀疏性)的想法与增加卷积核大小是正交的。虽然我们的工作部分受到了【11,36】的启发,但在我们的情况下,所提出的层对输入的各个频率分量的输出进行求和,从而捕获了多个感受野。
实现全局感受野的另一种方法是在傅里叶变换后进行频域空间混合(例如【4, 24, 47】)。然而,傅里叶变换将输入完全转换为频域表示,因此无法学习相邻像素之间的局部交互。相比之下,小波变换(WT)在将图像分解到不同频段的同时成功保留了一些局部信息,从而允许我们在不同分解级别上进行操作。此外,基于傅里叶的方法往往依赖于特定大小的输入来确定权重数量,因此很难用于下游任务。一项并行工作【20】利用神经隐式函数进行频域中的高效混合。
在本节中,我们首先描述如何使用卷积进行小波变换,然后我们提出在小波域中进行卷积的解决方案,称为WTConv。我们还描述了WTConv的理论优势并分析了其计算成本。
在本工作中,我们采用Haar小波变换,因为它高效且直接【14,16,32】。然而,我们注意到我们的方法并不局限于Haar小波,虽然使用其他小波基会增加计算成本。
给定一张图像,在一个空间维度(宽度或高度)上的一级Haar小波变换是通过与核和进行深度卷积,然后应用一个标准的2倍下采样算子来实现的。为了执行二维Haar小波变换,我们在两个维度上组合该操作,使用以下四个滤波器集合以步长为2进行深度卷积:
注意,是一个低通滤波器,而、、是一组高通滤波器。对于每个输入通道,卷积的输出
有四个通道,每个通道(在每个空间维度上)的分辨率都是的一半。是的低频分量,而、、分别是其水平、垂直和对角高频分量。
由于等式1中的核构成了一个正交归一化基,因此应用逆小波变换(IWT)可以通过转置卷积获得:
然后通过递归分解低频分量来获得级联小波分解,每一级的分解由下式给出:
其中,是当前级别。这会导致低频分量的频率分辨率增加,空间分辨率降低。
如第2.2节所述,增加卷积层的核大小会使参数数量(因此是自由度)二次增加。为了缓解这个问题,我们提出以下方案。首先,使用小波变换对输入的低频和高频内容进行滤波和下采样。然后,在对不同的频率图进行小核深度卷积之后,使用逆小波变换来构建输出。换句话说,该过程由下式给出:
其中是输入张量,是一个深度卷积核的权重张量,其输入通道数是的四倍。此操作不仅分离了不同频率分量之间的卷积,还允许较小的核在原始输入的较大区域上操作,即相对于输入增加了其感受野。图2对此进行了说明。
我们采用这一级联操作,并通过使用与等式4相同的级联原理进一步扩展它。该过程由下式给出:
其中是该层的输入,表示第3.1节中描述的第级的所有三个高频图。
为了组合不同频率的输出,我们利用了小波变换及其逆是线性操作这一事实,即。因此,执行
可以将不同级别的卷积相加,其中是从第级开始的累积输出。这与【11】中的方法一致,其中将两个不同大小卷积的输出相加作为最终输出。
与文献[11]不同,我们不能对,进行单独归一化,因为它们的单独归一化并不对应于原始域的归一化。相反,我们发现仅对每个频率分量进行通道级缩放以衡量其贡献就足够了。图3展示了2级小波变换(WT)的WTConv可视化。算法详见附录A。
在给定的卷积神经网络(CNN)中引入WTConv有两个主要的技术优势。首先,WT的每一级都会增加层的感受野大小,而可训练参数的数量只会有小幅增加。也就是说,WT的级级联频率分解,加上每一级固定大小的核,使得参数数量随级数线性增长(),而感受野则呈指数增长()。
第二个好处是,WTConv层的设计比标准卷积更能捕捉低频信息。这是因为输入的低频信息经过重复的小波变换(WT)分解后,低频信息被强调,从而增加了层对低频信息的响应。这一讨论补充了已知卷积层对输入高频信息响应的分析[19,45]。通过利用多频输入上的紧凑核,WTConv层将额外的参数放置在最需要它们的地方。
除了在标准基准测试上取得改进结果外,这些技术优势还体现在与大核方法相比,网络的可扩展性提高,对损坏和分布偏移的鲁棒性增强,以及对形状而非纹理的响应增强。我们在第4.4节中实证验证了这些假设。
深度卷积的浮点运算(FLOPs)计算成本为
其中是输入通道数,是输入的空间维度,是核大小,是每个维度的步长。例如,考虑空间维度为的单通道输入。使用大小为的核进行卷积会导致 FLOPs,而使用的核大小则会导致 FLOPs。考虑WTConv的卷积集合,每个小波域卷积都在降低2倍的空间维度上进行,尽管通道数是原始输入的4倍。这导致FLOP计数为
其中是WT的级数。继续前面的输入大小为的例子,使用大小为的核(覆盖的感受野)进行3级WTConv的多频卷积,会导致 FLOPs。当然,还需要加上WT计算本身的成本。我们注意到,当使用Haar基时,WT可以以非常高的效率实现[14]。也就是说,如果使用标准卷积操作的朴素实现,WT的FLOP计数为
因为四个核的大小为,每个空间维度的步长为2,并且在每个输入通道上操作(见第3.1节)。同样,类似的分析表明,逆小波变换(IWT)与WT具有相同的FLOP计数。继续前面的例子,对于3级WT和IWT,这会增加 FLOPs,总计 FLOPs,这仍然比具有相似感受野的标准深度卷积节省了大量计算成本。
在本节中,我们在多种设置下对WTConv进行了实验。首先,在第4.1节中,我们训练并评估了带有WTConv的ConvNeXt[38]在ImageNet-1K[10]分类任务上的性能。然后,在第4.2节和第4.3节中,我们将训练好的模型作为下游任务的骨干网络。最后,在第4.4节中,我们进一步分析了WTConv对网络的贡献。
对于ImageNet-1K [10],我们使用ConvNeXt [38]作为基础架构,并将的深度卷积替换为WTConv。ConvNeXt作为ResNet的扩展,主要由四个阶段组成,阶段之间包含下采样操作。我们将这些阶段的WTConv的层级设置为,并将卷积核大小设置为,以便在输入大小为时,每一步都能实现全局感受野。我们使用了120个和300个训练周期(详见附录B)的两种训练计划。
表1显示了120个训练周期计划的结果。由于所有网络都使用相同的ConvNeXt-T基础架构,因此我们报告了深度卷积(标记为D-W)的参数数量。请注意,为了公平比较,我们只报告了使用深度卷积核分解的SLaK和VAN的结果,因为我们只比较增加感受野的效果。我们强调,WTConv在得分最高的方法中实现了最佳结果,同时参数效率也最高。此外,它用不到GFNet一半数量的参数就实现了全局感受野。
在表2中,我们展示了300个训练周期计划的结果,并将WTConvNeXt与Swin [37]和ConvNeXt [38]进行了比较。表1和表2均表明,将WTConv引入ConvNeXt可以显著提高分类精度,而参数和FLOPs的增加却很少。例如,从ConvNeXt-S升级到ConvNeXt-B增加了39M个参数和6.7 GFLOPs,换来了0.7%的精度提升,而升级到WTConvNeXt-S仅增加了4M个参数和0.1 GFLOPs,就获得了0.5%的精度提升。
我们评估了WTConvNeXt作为UperNet [65]在ADE20K [69]语义分割任务中的骨干网络。我们使用MMSegmentation [7]进行UperNet的实现、训练和评估。训练遵循ConvNeXt的确切配置,没有进行任何参数调整。我们分别为来自第4.1节的120个和300个训练周期的预训练模型使用80K和160K次迭代的训练方案,并报告了单尺度测试的平均交并比(mIoU)指数。表3展示了结果,并显示使用WTConv时,mIoU提高了0.3-0.6%。
我们还评估了WTConvNeXt作为Cascade Mask R-CNN [2]在COCO数据集 [34]上的骨干网络。我们使用MMDetection [6]进行Cascade Mask R-CNN的实现、训练和评估。训练遵循ConvNeXt的确切配置,没有进行任何参数调整。我们分别为120个和300个训练周期的预训练模型使用1x和3x微调计划,并报告了框和掩码的平均精度(AP)。结果如表4所示,我们看到了显著的改进,因为AP和AP都提高了0.6-0.7%。详细结果见附录F。
可扩展性。我们对ImageNet-50/100/200 [48,57]上的分类任务进行了小规模的可扩展性分析,这些数据集是ImageNet [10]的子集,分别包含50/100/200个类别。在此实验中,我们使用MobilenetV2 [50],并将每个深度卷积替换为RepLK [11]、GFNet [47]、FFC [4]和提出的WTConv层。我们将WTConv的卷积核大小设置为。对于RepLK,我们使用与WTConv感受野最接近的可能卷积核大小,例如,对于具有感受野的2级WTConv,我们使用的卷积核大小。GFNet和FFC是基于傅里叶的方法。GFNet的全局滤波器层每个通道需要个参数,其中是输入的空间维度,因此它高度过参数化,特别是在MobileNetV2中,前几层的输入大小为。相比之下,FFC在不同频率上使用相同的权重,因此它不像GFNet那样直接依赖于。训练参数详见附录B。
结果总结在表5中,表明在增加感受野时,WTConv的缩放性能优于RepLK。我们假设这是由于数据不足以支持RepLK层的大量可训练参数。
这也与[36]在ImageNet-1K上的发现相一致,即在RepLK中简单地增加滤波器大小会损害结果。GFNet因过度参数化而严重受损,其结果显著下降。FFC表现更好,尽管有限的频率混合损害了其结果。
鲁棒性。我们在ImageNetC/[28,43]、ImageNet-R [27]、ImageNet-A [29]和ImageNet-Sketch [60]上对分类进行了鲁棒性评估。我们报告了ImageNet-C的平均损坏误差、ImageNet-的损坏误差以及所有其他基准测试的最高1准确率。我们还评估了在损坏[42]下的COCO上的目标检测,以损坏下的平均性能和相对性能(mPC和rPC)来衡量。我们报告了使用第4.1节中的300轮训练计划训练的模型的结果,未经任何修改或微调。
表6和表7总结了结果。注意,尽管WTConvNeXt在ImageNet-1K上的准确率比ConvNeXt高出0.3-0.4%,但在大多数鲁棒性数据集上,准确率提升超过1%,甚至高达2.2%。在损坏的目标检测中也出现了类似的趋势,这可以通过对低频响应的改善来解释[33]。更多详细表格和定性示例见附录G。
形状偏见。我们使用modelvshuman基准[18]来量化形状偏见(即基于形状而非纹理做出的预测的比例)的改进。增加形状偏见与人类感知相关,因此被认为是更可取的。
图4所示的结果证实了我们的假设,即WTConv使网络更具形状偏见,将“形状”决策的比例提高了8-12%。请注意,即使较小的WTConvNeXt-T对形状的响应也比较大的ConvNeXt网络更好,尽管后者在ImageNet-1k准确率上得分更高。这很可能是由于WTConv增加了对低频的关注度,因为形状通常与低频相关,而纹理与高频相关。定量结果见附录E。
有效感受野(ERF)。我们使用[11]提供的代码评估了WTConv对ConvNeXt-T的ERF[40]的贡献。理论上,在卷积神经网络(CNN)中,ERF与[40]成正比,其中是核大小,是网络的深度。但是,由于我们在使用较小核的同时引入WT来增加感受野,因此我们假设在考虑为该层引起的感受野的大小时,该关系仍然成立。ERF的实证评估包括从调整大小至的ImageNet验证集中随机采样50张图像。然后,对于每张图像,计算每个像素对最后一层生成的特征图中心点的梯度贡献。结果如图1所示,其中高贡献像素更亮。我们注意到,尽管WTConv的参数少于RepLK和SLaK,但它具有近乎全局的ERF。
消融研究。我们进行了消融研究,以观察WTConv层的不同配置如何影响最终结果。我们按照第4.1节的描述,在ImageNet-1K上训练WTConvNeXt-T 120个周期,并采用了各种配置。首先,我们试验了不同WT级别和核大小的组合;请注意,ConvNeXt的卷积操作针对分辨率为、、、(对于输入)的输入进行,分别允许最大WT级别为5、4、3、2。其次,我们每次在卷积中仅使用高频或低频组件之一来评估它们的贡献。最后,我们使用不同的小波基训练模型。
表8显示了所有描述的配置的结果。在这里,增加级别和核大小通常是有益的。我们还发现,单独使用每个频段可以提高模型的性能;然而,同时使用两者效果更好。结果证实,Haar小波变换(WT)就足够了,尽管探索其他基可能会提高性能。我们将其留给未来的工作。
尽管WTConv层不需要很多浮点运算(FLOPs),但在现有框架中,其运行时间可能会相对较高。这是由于多次顺序操作(WT-conv-IWT)的开销可能比计算本身更昂贵。然而,我们注意到,这可以通过使用专用实现来缓解,例如,在每个级别中并行执行WT和卷积以减少内存读取,或者就地执行WT和IWT以减少内存分配。更多实现细节见附录C。
在这项工作中,我们利用小波变换引入了WTConv,它是深度卷积的即插即用替代品,能够实现更大的感受野并更好地捕获输入中的低频信息。使用WTConv,可以纯卷积的方式配置全局感受野的空间混合。我们通过实证证明,WTConv显著增加了卷积神经网络(CNN)的有效感受野,改善了CNN的形状偏见,使网络对损坏更加鲁棒,并在各种视觉任务中取得了更好的性能。
在这里插入图片描述
训练过程与[11]类似,使用2个GPU设置,带有0.9动量的SGD优化器,每个GPU的批量大小为32,输入分辨率为,权重衰减为,以及5个周期的线性预热,之后是100个周期的余弦退火。考虑到GPU的数量,初始学习率调整为0.025。实现基于[41]。 B.2 ImageNet-1K
我们遵循[11,38]的300周期训练计划,使用AdamW[39]优化器,动量为0.9,权重衰减为0.05,批量大小为4096,学习率为,20个周期的线性预热,之后是余弦退火,RandAugment[8],标签平滑[53]系数为0.1,mixup[67]的,CutMix[66]的,Random Erasing[68]的概率为25%,随机深度[31]的丢弃路径率为T/S/B变体的10%/40%/50%,以及衰减因子为0.9999的指数移动平均(EMA)。我们还提供了较短训练计划(遵循[11])的比较,其配置与上述相似,但使用了120个周期,批量大小为2048,10个周期的预热,并且没有EMA。
我们所有实验使用的实现都相当朴素,并且可以进行许多改进。例如,除了我们在第5节中描述的内容外,Haar小波目前在我们的模型中作为常规卷积实现,其中包括作为FP32的乘以1和-1的操作。然而,这些操作可以改为求和与减法,并且可以同时对所有级别执行,以更有效地读取内存。
使用朴素实现,我们在GPU预热50个批次后,使用单个RTX3090测量了ConvNeXt和WTConvNeXt在300个大小为64的批次上的吞吐量。表9显示了以每秒图像数测量的吞吐量。可以看出,即使使用最朴素且未优化的层版本,WTConvNeXt的吞吐量也达到了原始网络的66-70%。
为了证明WTConv的兼容性,我们将其整合到另外两个网络GhostNet[25]和EfficientNet[54]中。WTConv级别的数量设置为在每个阶段相对于输入大小具有全局感受野。训练过程如第B.2节所述,采用120个周期的训练计划。结果见表10。
定量形状偏见结果见表11。
表12提供了如第4.3节所述的COCO目标检测实验的详细结果。
表13和表14分别提供了ImageNet-C和ImageNet 的详细结果。图5、图6、图7和图8展示了在不同类型损坏下的目标检测定性示例。这些示例表明,随着损坏程度的加剧,WTConvNeXt丢失的细节更少。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from .util import wavelet
class WTConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):
super(WTConv2d, self).__init__()
assert in_channels == out_channels
self.in_channels = in_channels
self.wt_levels = wt_levels
self.stride = stride
self.dilation = 1
self.wt_filter, self.iwt_filter = wavelet.create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
self.wt_function = partial(wavelet.wavelet_transform, filters = self.wt_filter)
self.iwt_function = partial(wavelet.inverse_wavelet_transform, filters = self.iwt_filter)
self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels, bias=bias)
self.base_scale = _ScaleModule([1,in_channels,1,1])
self.wavelet_convs = nn.ModuleList(
[nn.Conv2d(in_channels*4, in_channels*4, kernel_size, padding='same', stride=1, dilation=1, groups=in_channels*4, bias=False) for _ in range(self.wt_levels)]
)
self.wavelet_scale = nn.ModuleList(
[_ScaleModule([1,in_channels*4,1,1], init_scale=0.1) for _ in range(self.wt_levels)]
)
if self.stride > 1:
self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride, groups=in_channels)
else:
self.do_stride = None
def forward(self, x):
x_ll_in_levels = []
x_h_in_levels = []
shapes_in_levels = []
curr_x_ll = x
for i in range(self.wt_levels):
curr_shape = curr_x_ll.shape
shapes_in_levels.append(curr_shape)
if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
curr_x_ll = F.pad(curr_x_ll, curr_pads)
curr_x = self.wt_function(curr_x_ll)
curr_x_ll = curr_x[:,:,0,:,:]
shape_x = curr_x.shape
curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
curr_x_tag = curr_x_tag.reshape(shape_x)
x_ll_in_levels.append(curr_x_tag[:,:,0,:,:])
x_h_in_levels.append(curr_x_tag[:,:,1:4,:,:])
next_x_ll = 0
for i in range(self.wt_levels-1, -1, -1):
curr_x_ll = x_ll_in_levels.pop()
curr_x_h = x_h_in_levels.pop()
curr_shape = shapes_in_levels.pop()
curr_x_ll = curr_x_ll + next_x_ll
curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
next_x_ll = self.iwt_function(curr_x)
next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
x_tag = next_x_ll
assert len(x_ll_in_levels) == 0
x = self.base_scale(self.base_conv(x))
x = x + x_tag
if self.do_stride is not None:
x = self.do_stride(x)
return x
class _ScaleModule(nn.Module):
def __init__(self, dims, init_scale=1.0, init_bias=0):
super(_ScaleModule, self).__init__()
self.dims = dims
self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
self.bias = None
def forward(self, x):
return torch.mul(self.weight, x)