本文选自ICLR2022的论文"Entroformer: A Transformer-based Entropy Model for Learned Image Compression",源码可在附件中查看
本文介绍了Entroformer,一种基于Transformer的熵模型,用于深度学习图像压缩。与传统的基于卷积神经网络的熵模型不同,Entroformer利用Transformer的自注意力机制有效捕捉全局依赖性,并在图像压缩中实现了高效的概率分布估计。此外,本文提出了一个并行双向上下文模型,加速了解码过程。实验表明,Entroformer在图像压缩任务中表现优异,同时具有较高的时间效率。
图像压缩是计算机视觉中的基础研究领域,随着深度学习的发展,学习方法在这一任务中取得了多项突破。当前最先进的深度图像压缩模型建立在自编码器框架上,使用熵模型估计潜在表示的条件概率分布。本文旨在提高熵模型的预测能力,从而在不增加失真的情况下提高压缩率。
熵模型有多种类型,如超先验模型、上下文模型和组合方法。超先验方法通常利用量化潜在表示的附加信息,而上下文模型则通过自回归先验结合潜在的因果上下文进行预测。Entroformer结合了基于超先验和上下文模型的方法,提高了压缩效率。
图1 (b)即为(a)中熵模型使用Transformer的改进
Entroformer采用Transformer的编码器-解码器结构,使用自注意力机制关联不同位置的潜在表示,捕捉图像的空间和内容依赖性。此外,本文提出了多头注意力和top-k选择机制,以提取精确的概率分布依赖关系。
为了在图像压缩中提供更好的空间表示,Entroformer设计了一种基于相对位置编码的单位,该编码采用菱形边界,嵌入了图像压缩中的距离先验知识。
图2 相对位置编码
在自注意力机制中,Entroformer选择top-k最相似的关键点来计算注意力矩阵,从而减少不相关信息的干扰,稳定训练过程。
通过双向上下文模型,Entroformer在保持性能的同时,加速了解码过程。与传统的单向上下文模型相比,双向模型引入了未来上下文,提高了预测的准确性。
图3 双向上下文模型
实验通过计算率-失真性能(RD)评估了Entroformer的效果。结果表明,Entroformer在低比特率下比最先进的卷积神经网络方法提高了5.2%,比标准编解码器BPG提高了20.5%。此外,双向上下文模型在性能和速度之间实现了良好的平衡。
首先核心的代码逻辑如下
class TransDecoder2(TransDecoder):
train_scan_mode = 'default' # 'random', 'default'
test_scan_mode = 'checkboard'
is_decoder = False
def __init__(self, cin=0, cout=0, opt=None):
super().__init__(cin, cout, opt)
del self.sos_pred_token
def forward(self, x, manual_mask=None):
x = x.clone()
batch_size, channels, height, width = x.shape # input_shape
# Self-attention Mask & Token Mask
if manual_mask is None:
mask, token_mask, input_mask, output_mask = self.get_mask(batch_size, height, width)
else:
mask, token_mask, input_mask, output_mask = manual_mask
mask, input_mask, output_mask = mask.to(x.device), input_mask.to(x.device), output_mask.to(x.device)
token_mask = token_mask.to(x.device) if token_mask is not None else token_mask
# Mask Input
x.masked_fill_(~input_mask, 0.)
# Input Embedding
x = rearrange(x, 'b c h w -> b (h w) c')
inputs_embeds = self.to_patch_embedding(x)
# Init state
position_bias = None
# encoder_decoder_position_bias = None
hidden_states = inputs_embeds
topk = self.attn_topk
if self.training and topk != -1:
topk = np.random.randint(topk//2, topk*2)
for _, layer_module in enumerate(self.blocks):
# Transformer block
layer_outputs = layer_module(
hidden_states,
shape_2d=[height, width],
attention_mask=mask,
position_bias=position_bias,
topk=topk,
)
hidden_states = layer_outputs[0]
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention position bias), (cross-attention position bias)
if self.rpe_shared:
position_bias = layer_outputs[1]
# Out projection
out = self.mlp_head(hidden_states)
# Reshape Output to 2D map
out = rearrange(out, 'b (h w) c -> b c h w', h=height)
# Mask output
out.masked_fill_(~output_mask, 0.)
return out
def get_mask(self, b, h, w):
n = h*w
if self.training:
if(self.train_scan_mode == 'random' and hasattr(self, 'sampler')):
#mask = torch.ones(n, n).bool() # modified
token_mask = None
mask_random = (self.sampler.sample([n]) > self.mask_ratio).bool()
input_mask = mask_random.clone().view(h,w)
output_mask = ~mask_random.clone().view(h,w)
mask = repeat(mask_random.unsqueeze(0), '() n -> d n', d=n)
mask = mask | torch.eye(n).bool()
else:
#mask = torch.ones(n, n).bool()
token_mask = None
mask_checkboard = torch.ones((h, w)).bool()
mask_checkboard[0::2, 0::2] = 0
mask_checkboard[1::2, 1::2] = 0
input_mask = mask_checkboard.clone()
output_mask = ~mask_checkboard.clone()
mask = repeat(mask_checkboard.view(1,-1), '() n -> d n', d=n)
mask = mask | torch.eye(n).bool()
else:
if 'checkboard' in self.test_scan_mode:
#mask = torch.ones(n, n).bool()
token_mask = None
mask_checkboard = torch.ones((h, w)).bool()
if self.test_scan_mode == 'checkboard':
mask_checkboard[0::2, 0::2] = 0
mask_checkboard[1::2, 1::2] = 0
else:
mask_checkboard[0::2, 1::2] = 0
mask_checkboard[1::2, 0::2] = 0
input_mask = mask_checkboard.clone()
output_mask = ~mask_checkboard.clone()
mask = repeat(mask_checkboard.view(1,-1), '() n -> d n', d=n)
mask = mask | torch.eye(n).bool()
else:
raise ValueError("No such test scan mode.")
#print(input_mask)
mask = repeat(mask.unsqueeze(0), '() d n -> b d n', b=b)
token_mask = token_mask # torch.ones_like(mask).bool()
input_mask = repeat(input_mask.unsqueeze(0).unsqueeze(0), '() () h w -> b d h w', b=b, d=self.cin)
channel = self.dim if self.cout == 0 else self.cout
output_mask = repeat(output_mask.unsqueeze(0).unsqueeze(0), '() () h w -> b d h w', b=b, d=channel)
return mask, token_mask, input_mask, output_mask
代码解释如下:
sh scripts/pretrain.sh 0.3
sh scripts/train.sh [tradeoff_lambda(e.g. 0.02)]
(You may use your own dataset by modifying the train/test data path.)
# Kodak
sh scripts/test.sh [/path/to/kodak] [model_path]
(sh test_parallel.sh [/path/to/kodak] [model_path])
sh scripts/compress.sh [original.png] [model_path]
(sh compress_parallel.sh [original.png] [model_path])
sh scripts/decompress.sh [original.bin] [model_path]
(sh decompress_parallel.sh [original.bin] [model_path])
使用几种经典的图像压缩算法在lena图上做了测试,测试结果如下图所示: