本文对CVPR2023论文《Learned Image Compression with Mixed Transformer-CNN Architecture》进行解析和复现,作者团队来自日本早稻田大学的计算机科学与通信工程实验室。本文提出的方法是目前基于深度学习的图像压缩领域性能最佳的方法。 论文下载地址"https://arxiv.org/abs/2303.14978"
首先,这篇文章的出发点就是图像压缩最本源的目的,就是探索如何在相同的码率下获得更高质量的重建图像,或者说在得到的重建图像质量一样的情况下,如何进一步节省码率。
然后作者就站在前人做的利用深度学习压缩的基础上思考,有一批人使用CNN的方法,可以很好地降低空间冗余度,然后捕获图像的空域结构;另一批人使用Transformer的结构,来捕捉图像中长距离的空间依赖关系。于是作者就想,能不能把这两种方法做一个结合,做这么样一个结构,使其同时具备这两种算法的优点。于是就在此基础上,作者提出了本文的
在这一部分,我结合图文向大家解释一下基于深度学习进行图像压缩的基本框架流程,便于进一步理解本文方法。
先给出示意图如下:
首先是原图经过编码器得到一个潜在的表示y,就可以类比传统图像压缩里稀疏化的变换,只不过这里用一个可以学习的变换器来代替之前的人工设计的变换方法。 然后得到y之后,我们也是将y进行量化,熵编码,然后打包成码流进行传输。熵编码的部分呢,就是通过学习y的一些特征,来指导熵编码器对量化后的y进行更加精确的熵估计,一般是用基于高斯分布的算术编码器来进行熵编码,所以可以看到,这一部分学习的参数,往往也是均值和方差等等。
大家做的工作一般也是集中于如何改进这个编码器的结构,得到更加合理的潜在表示,然后一方面就是对熵编码器这里做一些工作,想方设法能使熵编码器对量化后的y进行更加准确地估计,从而做到节省码率。
那么我们首先来看一下TCM块的设计。
在这个结构里大家可以看到作者是使用了这个Swin Transformer 块和残差块来实现的一个两个方法的融合。
具体过程:输入的特征向量经过一个11的卷积,我们知道11卷积能够很好的糅合各通道之间的信息,然后下一步就是在通道维度对这个特征向量做一个切割,分别送入到Transformer块和残差块里进行学习。采用这种并行式的处理,一方面可以减小参数量,另一方面,能够分别学习各自擅长学习的特征。然后对各自得到的结果向量,先进行一个Concatenate,然后同样经过一个1*1卷积,对其各自的特征进行一个交互。
需要注意的是作者并没有将这个向量直接作为输出,而是进行了一个双阶段的设计,我认为这个也是以Swin-Transformer为启发,可以更好地对特征进行融合。
然后整体的表达式就是下面这三个式子。
接下来是作者的第二部分工作,提出了一种熵模型。
可以看出常规的熵模型是将整个y送入到一个提取超先验信息的网络当中,然后得到熵编码需要的参数。而这里的熵模型可以理解为将y沿着通道维度拆分成为多段,分别进行熵编码,最后再进行拼接。这么做的好处,不仅可以利用GPU的并行处理,而且还能通过上一段解码出的y进行指导,获得更加准确的估计。这个熵模型的想法是前人的工作,作者主要做的是在这样一个个参数估计的网络里引入了一种注意力机制。
基于Swin-Transformer块的注意力机制。可以看到SWAttention模块就是下图这样的一个结构,与前面TCM block的设计思路类似,作者都是想将非局部信息和局部信息做一个很好的结合,于是就在获取注意力的时候加入这样一个Swin-Transformer的基本块,来映射一些非局部的信息。
这里主要介绍模型部分的代码,对于一些基本卷积操作或训练时的基础设置等不做赘述。
def forward(self, x):
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
这是TCM模块的前向过程,也是本文的核心之一。可以看到与论文表述一致,通过对输入特征split操作之后,分别入到残差块和Swin-Transformer块里,再做拼接操作。
def forward(self, x):
resize = False
if (x.size(-1) <= self.window_size) or (x.size(-2) <= self.window_size):
padding_row = (self.window_size - x.size(-2)) // 2
padding_col = (self.window_size - x.size(-1)) // 2
x = F.pad(x, (padding_col, padding_col+1, padding_row, padding_row+1))
trans_x = Rearrange('b c h w -> b h w c')(x)
trans_x = self.block_1(trans_x)
trans_x = self.block_2(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
if resize:
x = F.pad(x, (-padding_col, -padding_col-1, -padding_row, -padding_row-1))
return trans_x
其中的Block块为内嵌的Transformer块:
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
其中,熵模型(即右侧红框部分)采用基于通道划分和窗注意力机制,其训练过程中的前向代码如下,估计出高斯建模的均值和方差。
for slice_index, y_slice in enumerate(y_slices):
support_slices = (y_hat_slices if self.max_support_slices < 0 else y_hat_slices[:self.max_support_slices])
mean_support = torch.cat([latent_means] + support_slices, dim=1)
mean_support = self.atten_mean[slice_index](mean_support)
mu = self.cc_mean_transforms[slice_index](mean_support)
mu = mu[:, :, :y_shape[0], :y_shape[1]]
mu_list.append(mu)
scale_support = torch.cat([latent_scales] + support_slices, dim=1)
scale_support = self.atten_scale[slice_index](scale_support)
scale = self.cc_scale_transforms[slice_index](scale_support)
scale = scale[:, :, :y_shape[0], :y_shape[1]]
scale_list.append(scale)
_, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu)
y_likelihood.append(y_slice_likelihood)
y_hat_slice = ste_round(y_slice - mu) + mu
# if self.training:
# lrp_support = torch.cat([mean_support + torch.randn(mean_support.size()).cuda().mul(scale_support), y_hat_slice], dim=1)
# else:
lrp_support = torch.cat([mean_support, y_hat_slice], dim=1)
lrp = self.lrp_transforms[slice_index](lrp_support)
lrp = 0.5 * torch.tanh(lrp)
y_hat_slice += lrp
y_hat_slices.append(y_hat_slice)
encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)
y_string = encoder.flush()
以上内容即为LIC-TCM模型的核心代码讲解,下面介绍具体的复现流程。
首先,本文也是基于CompressAI的API进行开发和训练的,因此,在此之前需要先安装Compressai库,具体的安装方式可以参考我的另一篇博文《STF—顶会图像压缩方法》,此处不再赘述。
复现时主要分为训练和测试评估两个方面。
其中训练部分的设置如下:
CUDA_VISIBLE_DEVICES='0' python -u ./train.py -d [path of training dataset] \
--cuda --N 128 --lambda 0.05 --epochs 50 --lr_epoch 45 48 \
--save_path [path for checkpoint]
--checkpoint [path of the pretrained checkpoint]
其中,
这些参数均有默认值,可以参考如下参数定义的代码:
def parse_args(argv):
parser = argparse.ArgumentParser(description="Example training script.")
parser.add_argument(
"-m",
"--model",
default="bmshj2018-factorized",
choices=models.keys(),
help="Model architecture (default: %(default)s)",
)
parser.add_argument(
"-d", "--dataset", type=str, required=True, help="Training dataset"
)
parser.add_argument(
"-e",
"--epochs",
default=50,
type=int,
help="Number of epochs (default: %(default)s)",
)
parser.add_argument(
"-lr",
"--learning-rate",
default=1e-4,
type=float,
help="Learning rate (default: %(default)s)",
)
parser.add_argument(
"-n",
"--num-workers",
type=int,
default=20,
help="Dataloaders threads (default: %(default)s)",
)
parser.add_argument(
"--lambda",
dest="lmbda",
type=float,
default=3,
help="Bit-rate distortion parameter (default: %(default)s)",
)
parser.add_argument(
"--batch-size", type=int, default=8, help="Batch size (default: %(default)s)"
)
parser.add_argument(
"--test-batch-size",
type=int,
default=8,
help="Test batch size (default: %(default)s)",
)
parser.add_argument(
"--aux-learning-rate",
default=1e-3,
help="Auxiliary loss learning rate (default: %(default)s)",
)
parser.add_argument(
"--patch-size",
type=int,
nargs=2,
default=(256, 256),
help="Size of the patches to be cropped (default: %(default)s)",
)
parser.add_argument("--cuda", action="store_true", help="Use cuda")
parser.add_argument(
"--save", action="store_true", default=True, help="Save model to disk"
)
parser.add_argument(
"--seed", type=float, default=100, help="Set random seed for reproducibility"
)
parser.add_argument(
"--clip_max_norm",
default=1.0,
type=float,
help="gradient clipping max norm (default: %(default)s",
)
parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
parser.add_argument("--type", type=str, default='mse', help="loss type", choices=['mse', "ms-ssim"])
parser.add_argument("--save_path", type=str, help="save_path")
parser.add_argument(
"--skip_epoch", type=int, default=0
)
parser.add_argument(
"--N", type=int, default=128,
)
parser.add_argument(
"--lr_epoch", nargs='+', type=int
)
parser.add_argument(
"--continue_train", action="store_true", default=True
)
args = parser.parse_args(argv)
return args
这是训练轮数设置为1000轮,学习率设置为1e-4时的训练过程截图,
其中MSE Loss为训练过程中的均方差损失,BPP Loss为率失真损失,Aux Loss为熵模型的辅助损失,BCE Loss为个人创新点设计的损失,感兴趣的同学可以私信与本人交流(算一个创新点,并且已经有一些初步的实验结果)。
而测试部分的代码逻辑就较为简单:
python eval.py --checkpoint [path of the pretrained checkpoint] --data [path of testing dataset] --cuda
核心的部署需求即为PyTorch环境和Compressai的API。
附件部分提供了数据集(DIV2K和Kodak)下载地址,以及基于MSE损失优化的预训练模型下载地址。
编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!