前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >【论文复现】图像压缩算法

【论文复现】图像压缩算法

作者头像
Eternity._
发布2024-12-28 10:33:59
发布2024-12-28 10:33:59
16600
代码可运行
举报
文章被收录于专栏:登神长阶登神长阶
运行总次数:0
代码可运行

本文对CVPR2023论文《Learned Image Compression with Mixed Transformer-CNN Architecture》进行解析和复现,作者团队来自日本早稻田大学的计算机科学与通信工程实验室。本文提出的方法是目前基于深度学习的图像压缩领域性能最佳的方法。 论文下载地址"https://arxiv.org/abs/2303.14978"

文章出发点


首先,这篇文章的出发点就是图像压缩最本源的目的,就是探索如何在相同的码率下获得更高质量的重建图像,或者说在得到的重建图像质量一样的情况下,如何进一步节省码率。

然后作者就站在前人做的利用深度学习压缩的基础上思考,有一批人使用CNN的方法,可以很好地降低空间冗余度,然后捕获图像的空域结构;另一批人使用Transformer的结构,来捕捉图像中长距离的空间依赖关系。于是作者就想,能不能把这两种方法做一个结合,做这么样一个结构,使其同时具备这两种算法的优点。于是就在此基础上,作者提出了本文的

先验知识


在这一部分,我结合图文向大家解释一下基于深度学习进行图像压缩的基本框架流程,便于进一步理解本文方法。

先给出示意图如下:

首先是原图经过编码器得到一个潜在的表示y,就可以类比传统图像压缩里稀疏化的变换,只不过这里用一个可以学习的变换器来代替之前的人工设计的变换方法。 然后得到y之后,我们也是将y进行量化,熵编码,然后打包成码流进行传输。熵编码的部分呢,就是通过学习y的一些特征,来指导熵编码器对量化后的y进行更加精确的熵估计,一般是用基于高斯分布的算术编码器来进行熵编码,所以可以看到,这一部分学习的参数,往往也是均值和方差等等。

大家做的工作一般也是集中于如何改进这个编码器的结构,得到更加合理的潜在表示,然后一方面就是对熵编码器这里做一些工作,想方设法能使熵编码器对量化后的y进行更加准确地估计,从而做到节省码率。

LIC-TCM算法亮点一:TCM块


那么我们首先来看一下TCM块的设计。

在这个结构里大家可以看到作者是使用了这个Swin Transformer 块和残差块来实现的一个两个方法的融合。

具体过程:输入的特征向量经过一个11的卷积,我们知道11卷积能够很好的糅合各通道之间的信息,然后下一步就是在通道维度对这个特征向量做一个切割,分别送入到Transformer块和残差块里进行学习。采用这种并行式的处理,一方面可以减小参数量,另一方面,能够分别学习各自擅长学习的特征。然后对各自得到的结果向量,先进行一个Concatenate,然后同样经过一个1*1卷积,对其各自的特征进行一个交互。

需要注意的是作者并没有将这个向量直接作为输出,而是进行了一个双阶段的设计,我认为这个也是以Swin-Transformer为启发,可以更好地对特征进行融合。

然后整体的表达式就是下面这三个式子。

LIC-TCM算法亮点二:熵模型设计


接下来是作者的第二部分工作,提出了一种熵模型。

可以看出常规的熵模型是将整个y送入到一个提取超先验信息的网络当中,然后得到熵编码需要的参数。而这里的熵模型可以理解为将y沿着通道维度拆分成为多段,分别进行熵编码,最后再进行拼接。这么做的好处,不仅可以利用GPU的并行处理,而且还能通过上一段解码出的y进行指导,获得更加准确的估计。这个熵模型的想法是前人的工作,作者主要做的是在这样一个个参数估计的网络里引入了一种注意力机制。

基于Swin-Transformer块的注意力机制。可以看到SWAttention模块就是下图这样的一个结构,与前面TCM block的设计思路类似,作者都是想将非局部信息和局部信息做一个很好的结合,于是就在获取注意力的时候加入这样一个Swin-Transformer的基本块,来映射一些非局部的信息。

核心代码解读


这里主要介绍模型部分的代码,对于一些基本卷积操作或训练时的基础设置等不做赘述。

  • 首先是TCM模块部分
代码语言:javascript
代码运行次数:0
复制
def forward(self, x):
    conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
    conv_x = self.conv_block(conv_x) + conv_x
    trans_x = Rearrange('b c h w -> b h w c')(trans_x)
    trans_x = self.trans_block(trans_x)
    trans_x = Rearrange('b h w c -> b c h w')(trans_x)
    res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
    x = x + res
    return x

这是TCM模块的前向过程,也是本文的核心之一。可以看到与论文表述一致,通过对输入特征split操作之后,分别入到残差块和Swin-Transformer块里,再做拼接操作。

  • Swin-Transformer块的代码如下:
代码语言:javascript
代码运行次数:0
复制
def forward(self, x):
    resize = False
    if (x.size(-1) <= self.window_size) or (x.size(-2) <= self.window_size):
        padding_row = (self.window_size - x.size(-2)) // 2
        padding_col = (self.window_size - x.size(-1)) // 2
        x = F.pad(x, (padding_col, padding_col+1, padding_row, padding_row+1))
    trans_x = Rearrange('b c h w -> b h w c')(x)
    trans_x = self.block_1(trans_x)
    trans_x =  self.block_2(trans_x)
    trans_x = Rearrange('b h w c -> b c h w')(trans_x)
    if resize:
        x = F.pad(x, (-padding_col, -padding_col-1, -padding_row, -padding_row-1))
    return trans_x

其中的Block块为内嵌的Transformer块:

代码语言:javascript
代码运行次数:0
复制
def forward(self, x):
    x = x + self.drop_path(self.msa(self.ln1(x)))
    x = x + self.drop_path(self.mlp(self.ln2(x)))
    return x
  • 最后整个模型的结构如下:

其中,熵模型(即右侧红框部分)采用基于通道划分和窗注意力机制,其训练过程中的前向代码如下,估计出高斯建模的均值和方差。

代码语言:javascript
代码运行次数:0
复制
for slice_index, y_slice in enumerate(y_slices):
    support_slices = (y_hat_slices if self.max_support_slices < 0 else y_hat_slices[:self.max_support_slices])
    mean_support = torch.cat([latent_means] + support_slices, dim=1)
    mean_support = self.atten_mean[slice_index](mean_support)
    mu = self.cc_mean_transforms[slice_index](mean_support)
    mu = mu[:, :, :y_shape[0], :y_shape[1]]
    mu_list.append(mu)
    scale_support = torch.cat([latent_scales] + support_slices, dim=1)
    scale_support = self.atten_scale[slice_index](scale_support)
    scale = self.cc_scale_transforms[slice_index](scale_support)
    scale = scale[:, :, :y_shape[0], :y_shape[1]]
    scale_list.append(scale)
    _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
    y_likelihood.append(y_slice_likelihood)
    y_hat_slice = ste_round(y_slice - mu) + mu
    # if self.training:
    #     lrp_support = torch.cat([mean_support + torch.randn(mean_support.size()).cuda().mul(scale_support), y_hat_slice], dim=1)
    # else:
    lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
    lrp = self.lrp_transforms[slice_index](lrp_support)
    lrp = 0.5 * torch.tanh(lrp)
    y_hat_slice += lrp

    y_hat_slices.append(y_hat_slice)
  • 编码器部分则采用算术编码器,具体调用代码如下:
代码语言:javascript
代码运行次数:0
复制
encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)
y_string = encoder.flush()

以上内容即为LIC-TCM模型的核心代码讲解,下面介绍具体的复现流程。

复现流程


首先,本文也是基于CompressAI的API进行开发和训练的,因此,在此之前需要先安装Compressai库,具体的安装方式可以参考我的另一篇博文《STF—顶会图像压缩方法》,此处不再赘述。

复现时主要分为训练和测试评估两个方面。

其中训练部分的设置如下:

代码语言:javascript
代码运行次数:0
复制
CUDA_VISIBLE_DEVICES='0' python -u ./train.py -d [path of training dataset] \
    --cuda --N 128 --lambda 0.05 --epochs 50 --lr_epoch 45 48 \
    --save_path [path for checkpoint]
    --checkpoint [path of the pretrained checkpoint]

其中,

  • "CUDA_VISIBLE_DEVICES"选用可调用的显卡序号,注意这里支持多显卡调用;
  • "d"为训练所用数据集,这里可以根据自己的任务选择合适的数据集,附件中提供了DIV2K的下载地址;
  • "–cuda"表示使用GPU进行训练;
  • "–lambda"即为率失真的比例系数;
  • "epochs"为训练的轮数;
  • "–lr_epoch"表示milestone,即在训练至多少轮的时候,学习率降为10%;
  • "–save_path"为模型的保存路径;
  • "–checkpoint"为预训练的模型参数

这些参数均有默认值,可以参考如下参数定义的代码:

代码语言:javascript
代码运行次数:0
复制
def parse_args(argv):
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="bmshj2018-factorized",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=50,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=20,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=3,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=8, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=8,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument(
        "--save", action="store_true", default=True, help="Save model to disk"
    )
    parser.add_argument(
        "--seed", type=float, default=100, help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="gradient clipping max norm (default: %(default)s",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    parser.add_argument("--type", type=str, default='mse', help="loss type", choices=['mse', "ms-ssim"])
    parser.add_argument("--save_path", type=str, help="save_path")
    parser.add_argument(
        "--skip_epoch", type=int, default=0
    )
    parser.add_argument(
        "--N", type=int, default=128,
    )
    parser.add_argument(
        "--lr_epoch", nargs='+', type=int
    )
    parser.add_argument(
        "--continue_train", action="store_true", default=True
    )
    args = parser.parse_args(argv)
    return args

这是训练轮数设置为1000轮,学习率设置为1e-4时的训练过程截图,

其中MSE Loss为训练过程中的均方差损失,BPP Loss为率失真损失,Aux Loss为熵模型的辅助损失,BCE Loss为个人创新点设计的损失,感兴趣的同学可以私信与本人交流(算一个创新点,并且已经有一些初步的实验结果)。

而测试部分的代码逻辑就较为简单:

代码语言:javascript
代码运行次数:0
复制
python eval.py --checkpoint [path of the pretrained checkpoint] --data [path of testing dataset] --cuda
  • "checkpoint"即为测试所用的模型;
  • "–data"为要测试的数据

部署方式


核心的部署需求即为PyTorch环境和Compressai的API。

附件部分提供了数据集(DIV2K和Kodak)下载地址,以及基于MSE损失优化的预训练模型下载地址。


编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-12-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章出发点
  • 先验知识
  • LIC-TCM算法亮点一:TCM块
  • LIC-TCM算法亮点二:熵模型设计
  • 核心代码解读
  • 复现流程
  • 部署方式
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档