前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >专栏 >Entroformer图像编码

Entroformer图像编码

作者头像
Srlua
发布2024-12-20 09:47:35
发布2024-12-20 09:47:35
12200
代码可运行
举报
文章被收录于专栏:CSDN社区搬运CSDN社区搬运
运行总次数:0
代码可运行

本文选自ICLR2022的论文"Entroformer: A Transformer-based Entropy Model for Learned Image Compression",源码可在附件中查看

总体介绍

本文介绍了Entroformer,一种基于Transformer的熵模型,用于深度学习图像压缩。与传统的基于卷积神经网络的熵模型不同,Entroformer利用Transformer的自注意力机制有效捕捉全局依赖性,并在图像压缩中实现了高效的概率分布估计。此外,本文提出了一个并行双向上下文模型,加速了解码过程。实验表明,Entroformer在图像压缩任务中表现优异,同时具有较高的时间效率。

背景

图像压缩是计算机视觉中的基础研究领域,随着深度学习的发展,学习方法在这一任务中取得了多项突破。当前最先进的深度图像压缩模型建立在自编码器框架上,使用熵模型估计潜在表示的条件概率分布。本文旨在提高熵模型的预测能力,从而在不增加失真的情况下提高压缩率。

提出方法

超先验与上下文模型的压缩

熵模型有多种类型,如超先验模型、上下文模型和组合方法。超先验方法通常利用量化潜在表示的附加信息,而上下文模型则通过自回归先验结合潜在的因果上下文进行预测。Entroformer结合了基于超先验和上下文模型的方法,提高了压缩效率。

图1 (b)即为(a)中熵模型使用Transformer的改进

Transformer-based 熵模型

Transformer架构

Entroformer采用Transformer的编码器-解码器结构,使用自注意力机制关联不同位置的潜在表示,捕捉图像的空间和内容依赖性。此外,本文提出了多头注意力和top-k选择机制,以提取精确的概率分布依赖关系。

位置编码

为了在图像压缩中提供更好的空间表示,Entroformer设计了一种基于相对位置编码的单位,该编码采用菱形边界,嵌入了图像压缩中的距离先验知识。

图2 相对位置编码

Top-k 自注意力机制

在自注意力机制中,Entroformer选择top-k最相似的关键点来计算注意力矩阵,从而减少不相关信息的干扰,稳定训练过程。

并行双向上下文模型

通过双向上下文模型,Entroformer在保持性能的同时,加速了解码过程。与传统的单向上下文模型相比,双向模型引入了未来上下文,提高了预测的准确性。

图3 双向上下文模型

实验结果

实验通过计算率-失真性能(RD)评估了Entroformer的效果。结果表明,Entroformer在低比特率下比最先进的卷积神经网络方法提高了5.2%,比标准编解码器BPG提高了20.5%。此外,双向上下文模型在性能和速度之间实现了良好的平衡。

代码复现

首先核心的代码逻辑如下

代码语言:javascript
代码运行次数:0
运行
复制
class TransDecoder2(TransDecoder):
    train_scan_mode = 'default'  #  'random', 'default'
    test_scan_mode = 'checkboard'
    is_decoder = False
    def __init__(self, cin=0, cout=0, opt=None):
        super().__init__(cin, cout, opt)
        del self.sos_pred_token

    def forward(self, x, manual_mask=None):
        x = x.clone()
        batch_size, channels, height, width  = x.shape   # input_shape

        # Self-attention Mask & Token Mask
        if manual_mask is None:
            mask, token_mask, input_mask, output_mask = self.get_mask(batch_size, height, width)
        else:
            mask, token_mask, input_mask, output_mask = manual_mask
        mask, input_mask, output_mask = mask.to(x.device), input_mask.to(x.device), output_mask.to(x.device)
        token_mask = token_mask.to(x.device) if token_mask is not None else token_mask

        # Mask Input
        x.masked_fill_(~input_mask, 0.)
        # Input Embedding
        x = rearrange(x, 'b c h w -> b (h w) c')
        inputs_embeds = self.to_patch_embedding(x)

        # Init state
        position_bias = None
        # encoder_decoder_position_bias = None
        hidden_states = inputs_embeds

        topk = self.attn_topk
        if self.training and topk != -1:
            topk = np.random.randint(topk//2, topk*2)

        for _, layer_module in enumerate(self.blocks):
            # Transformer block
            layer_outputs = layer_module(
                hidden_states,
                shape_2d=[height, width],
                attention_mask=mask,
                position_bias=position_bias,
                topk=topk,
            )

            hidden_states = layer_outputs[0]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, (self-attention position bias), (cross-attention position bias)
            if self.rpe_shared:
                position_bias = layer_outputs[1]

        # Out projection
        out = self.mlp_head(hidden_states)        
        # Reshape Output to 2D map
        out = rearrange(out, 'b (h w) c -> b c h w', h=height)
        # Mask output
        out.masked_fill_(~output_mask, 0.)
        return out            

    def get_mask(self, b, h, w):
        n = h*w
        if self.training:
            if(self.train_scan_mode == 'random' and hasattr(self, 'sampler')):
                #mask = torch.ones(n, n).bool()    # modified
                token_mask = None
                mask_random = (self.sampler.sample([n]) > self.mask_ratio).bool()
                input_mask = mask_random.clone().view(h,w)
                output_mask = ~mask_random.clone().view(h,w)
                mask = repeat(mask_random.unsqueeze(0), '() n -> d n', d=n)
                mask = mask | torch.eye(n).bool()
            else:
                #mask = torch.ones(n, n).bool()
                token_mask = None
                mask_checkboard = torch.ones((h, w)).bool()
                mask_checkboard[0::2, 0::2] = 0
                mask_checkboard[1::2, 1::2] = 0
                input_mask = mask_checkboard.clone()
                output_mask = ~mask_checkboard.clone()
                mask = repeat(mask_checkboard.view(1,-1), '() n -> d n', d=n)
                mask = mask | torch.eye(n).bool()
        else:
            if 'checkboard' in self.test_scan_mode:
                #mask = torch.ones(n, n).bool()
                token_mask = None
                mask_checkboard = torch.ones((h, w)).bool()
                if self.test_scan_mode == 'checkboard':
                    mask_checkboard[0::2, 0::2] = 0
                    mask_checkboard[1::2, 1::2] = 0
                else:
                    mask_checkboard[0::2, 1::2] = 0
                    mask_checkboard[1::2, 0::2] = 0
                input_mask = mask_checkboard.clone()
                output_mask = ~mask_checkboard.clone()
                mask = repeat(mask_checkboard.view(1,-1), '() n -> d n', d=n)
                mask = mask | torch.eye(n).bool()
            else:
                raise ValueError("No such test scan mode.")

        #print(input_mask)
        mask = repeat(mask.unsqueeze(0), '() d n -> b d n', b=b)
        token_mask = token_mask  # torch.ones_like(mask).bool()
        input_mask = repeat(input_mask.unsqueeze(0).unsqueeze(0), '() () h w -> b d h w', b=b, d=self.cin)
        channel = self.dim if self.cout == 0 else self.cout
        output_mask = repeat(output_mask.unsqueeze(0).unsqueeze(0), '() () h w -> b d h w', b=b, d=channel)

        return mask, token_mask, input_mask, output_mask

代码解释如下:

  • 前向传播方法forward,接收输入x和可选的手动遮罩manual_mask。
  • 根据输入的形状创建遮罩,并将输入转换为嵌入表示。
  • 通过多个Transformer块进行处理,并共享位置偏置。
  • 最后,输出通过MLP头部投影并重塑为2D地图。
  • 支持随机遮罩和棋盘格遮罩模式,并根据输入的高度和宽度生成对应的输入和输出遮罩。

代码部署及使用

Train

代码语言:javascript
代码运行次数:0
运行
复制
sh scripts/pretrain.sh 0.3
sh scripts/train.sh [tradeoff_lambda(e.g. 0.02)]
(You may use your own dataset by modifying the train/test data path.)

Evaluate

代码语言:javascript
代码运行次数:0
运行
复制
# Kodak
sh scripts/test.sh [/path/to/kodak] [model_path]
(sh test_parallel.sh [/path/to/kodak] [model_path])

Compress

代码语言:javascript
代码运行次数:0
运行
复制
sh scripts/compress.sh [original.png] [model_path]
(sh compress_parallel.sh [original.png] [model_path])

Decompress

代码语言:javascript
代码运行次数:0
运行
复制
sh scripts/decompress.sh [original.bin] [model_path]
(sh decompress_parallel.sh [original.bin] [model_path])

本人测试结果

使用几种经典的图像压缩算法在lena图上做了测试,测试结果如下图所示:

​​

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-12-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 总体介绍
  • 背景
  • 提出方法
    • 超先验与上下文模型的压缩
    • Transformer-based 熵模型
      • Transformer架构
      • 位置编码
      • Top-k 自注意力机制
      • 并行双向上下文模型
    • 实验结果
  • 代码复现
  • 代码部署及使用
    • Train
    • Evaluate
    • Compress
    • Decompress
  • 本人测试结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档