首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >TeaCache:让扩散模型少算几步,但尽量不掉画质

TeaCache:让扩散模型少算几步,但尽量不掉画质

作者头像
Michael阿明
发布2026-05-19 16:39:22
发布2026-05-19 16:39:22
1160
举报

文章目录
  • 1. 为什么 TeaCache 能加速?
  • 2. TeaCache 判断“能不能跳过”的依据
  • 3. 普通推理 vs TeaCache 推理
  • 4. 最关键参数:`rel_l1_thresh`
  • 5. 适用场景
  • 6. vLLM-Omni 正式用法
    • 安装环境
    • 代码
    • 如何调参?
  • 7. 工程踩坑清单
    • 坑 1:把 TeaCache 当成 LLM KV Cache
    • 坑 2:`rel_l1_thresh` 太大
    • 坑 3:步数太少时加速不明显
    • 坑 4:coefficients 不能乱迁移
    • 坑 5:服务化时 cache state 必须按请求隔离
  • 8. 总结

扩散模型生成图片/视频时,本质是在很多个 denoising step 中反复调用 Transformer/DiT。TeaCache 的核心思想很简单:

❝如果当前 step 和上一次完整计算的 step 足够相似,就不重新完整跑 Transformer,而是复用上一次缓存的 residual / 输出近似。

它不是 LLM 的 KV Cache,也不是缓存最终图片,而是缓存扩散去噪过程中的中间计算结果。 TeaCache 论文将其称为 Timestep Embedding Aware Cache,即利用 timestep embedding 估计不同 step 之间模型输出变化,从而决定是否缓存和复用。论文报告在 Open-Sora-Plan 上最高获得 4.41x 加速,VBench 质量分数仅下降 0.07%


1. 为什么 TeaCache 能加速?

扩散模型每一步大概都在做:

代码语言:javascript
复制
latent_t + timestep_t + prompt_condition
        ↓
DiT / Transformer
        ↓
预测噪声 / velocity / residual
        ↓
scheduler 更新 latent

最耗时的通常是中间的 DiT / Transformer blocks,包括 Attention、MLP、Norm、Residual 等。

普通推理是:

代码语言:javascript
复制
step 50:完整计算 Transformer
step 49:完整计算 Transformer
step 48:完整计算 Transformer
...
step 1 :完整计算 Transformer

TeaCache 的做法是:

代码语言:javascript
复制
step 50:完整计算 Transformer,缓存 residual
step 49:判断变化小,复用 step 50 的 residual
step 48:判断变化小,继续复用
step 47:变化累计变大,重新完整计算并更新缓存
在这里插入图片描述
在这里插入图片描述

vLLM-Omni 官方文档也将 TeaCache 描述为:当连续 timestep 足够相似时缓存 Transformer 计算,从而实现约 1.5x–2.0x 加速,并通过输入相似性动态判断是否复用缓存。


2. TeaCache 判断“能不能跳过”的依据

TeaCache 不会直接比较完整模型输出,因为如果已经完整跑了一次模型,那就没有加速意义了。

它使用一个更便宜的代理量:

代码语言:javascript
复制
timestep embedding 调制后的 noisy input

然后比较当前 step 和上一次完整计算 step 的差异:

代码语言:javascript
复制
rel_l1 = mean(abs(current_modulated_input - previous_modulated_input)) \
         / mean(abs(previous_modulated_input))

再通过模型相关的多项式系数做 rescale,估计真实输出差异。

如果累计差异低于阈值,就复用缓存;如果超过阈值,就重新完整计算。

TeaCache 论文明确指出,它不直接使用耗时的模型输出差异,而是利用与模型输出强相关、但计算成本很低的模型输入差异来判断缓存时机。


3. 普通推理 vs TeaCache 推理

在这里插入图片描述
在这里插入图片描述

TeaCache 省掉的主要是:

代码语言:javascript
复制
部分 denoising step 里的 Transformer 主体计算

它通常不会省掉

代码语言:javascript
复制
scheduler 更新
VAE decode
text encoder
prompt 编码
CPU/GPU 数据搬运

所以如果你的端到端瓶颈主要在 VAE、CPU offload 或 IO,那么 TeaCache 的实际加速会低于理论加速。


4. 最关键参数:rel_l1_thresh

TeaCache 最重要的参数是:

代码语言:javascript
复制
rel_l1_thresh

它控制缓存复用的激进程度:

代码语言:javascript
复制
阈值越小:更保守,完整计算更多,质量更稳,速度提升较小
阈值越大:更激进,缓存复用更多,速度更快,质量风险更高

vLLM-Omni 文档中 rel_l1_thresh 默认值是 0.2,建议范围是 0.1–0.8

低值优先质量,高值优先速度。

在这里插入图片描述
在这里插入图片描述

建议初始设置:

代码语言:javascript
复制
质量优先:0.10 ~ 0.20
均衡配置:0.20 ~ 0.40
速度优先:0.50 ~ 0.80

生产环境不要一上来拉到 0.8。更稳的方式是:

代码语言:javascript
复制
0.2 → 0.3 → 0.4 → 对比质量和耗时

5. 适用场景

TeaCache 更适合:

代码语言:javascript
复制
1. DiT / Transformer-based diffusion 模型
2. 图像生成、视频生成、音频扩散生成
3. denoising steps 较多的推理任务
4. 对 1.5x~2x 加速有价值,同时能容忍极小质量波动的生产服务
5. 单卡加速场景

vLLM-Omni 官方文档也建议 TeaCache 用于需要更快推理、且能容忍极小质量损失的生产场景;不太适合极致画质要求或非常短步数推理,例如小于 20 steps 的情况。

不太适合:

代码语言:javascript
复制
1. 4~8 step 的蒸馏模型
2. 强编辑、强控制、强文字生成场景
3. 对画质零损失要求极高的任务
4. 主要瓶颈不在 Transformer 的 pipeline

6. vLLM-Omni 正式用法

如果你用的是 vLLM-Omni,并且模型后端支持 TeaCache,可以直接这样开:

代码语言:javascript
复制
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

omni = Omni(
    model="Qwen/Qwen-Image",
    cache_backend="tea_cache",
)

outputs = omni.generate(
    "A cat sitting on a windowsill",
    OmniDiffusionSamplingParams(num_inference_steps=50),
)

在线服务方式:

代码语言:javascript
复制
vllm serve Qwen/Qwen-Image --omni --port 8091 \
  --cache-backend tea_cache \
  --cache-config '{"rel_l1_thresh": 0.2}'

这些参数写法来自 vLLM-Omni TeaCache 官方文档。

使用 facebook/DiT-XL-2-256 重点演示“根据相邻 step 输入差异决定是否复用 residual”。 它不是官方 TeaCache 的完整实现,因为官方 TeaCache 会使用 timestep embedding 调制输入、模型专属 coefficients、多项式 rescale 等细节。

安装环境

代码语言:javascript
复制
pip install -U diffusers transformers accelerate safetensors scipy pillow

代码

代码语言:javascript
复制
import gc
import time
import types

import torch
from IPython.display import display
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
from diffusers.models.modeling_outputs import Transformer2DModelOutput


MODEL_ID = "facebook/DiT-XL-2-256"
NUM_STEPS = 25
CLASS_ID = 207# ImageNet class id.


class SimpleTeaCacheState:
    """Minimal TeaCache-style state for DiT transformer calls."""

    def __init__(self, rel_l1_thresh=0.20, num_steps=25):
        self.rel_l1_thresh = rel_l1_thresh
        self.num_steps = num_steps
        self.reset()

    def reset(self):
        self.step_idx = 0
        self.accumulated_rel_l1 = 0.0
        self.previous_input = None
        self.previous_residual = None
        self.previous_sample = None
        self.full_compute_steps = 0
        self.cached_steps = 0

    @torch.no_grad()
    def should_compute(self, hidden_states: torch.Tensor) -> bool:
        is_first = self.step_idx == 0
        is_last = self.step_idx >= self.num_steps - 1

        if is_first or is_last:
            self.previous_input = hidden_states.detach().float()
            self.accumulated_rel_l1 = 0.0
            returnTrue

        if self.previous_input isNoneor self.previous_sample isNone:
            self.previous_input = hidden_states.detach().float()
            returnTrue

        current = hidden_states.detach().float()
        previous = self.previous_input

        denom = previous.abs().mean().clamp_min(1e-6)
        rel_l1 = (current - previous).abs().mean() / denom
        self.accumulated_rel_l1 += float(rel_l1.item())

        if self.accumulated_rel_l1 < self.rel_l1_thresh:
            returnFalse

        self.accumulated_rel_l1 = 0.0
        self.previous_input = current
        returnTrue

    def next_step(self):
        self.step_idx += 1


def clear_cuda_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def load_dit_pipeline(device):
    dtype = torch.float16 if device == "cuda"else torch.float32
    pipe = DiTPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        token=False,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe.to(device)


def enable_simple_teacache_for_dit(pipe, rel_l1_thresh=0.20, num_steps=25):
    state = SimpleTeaCacheState(
        rel_l1_thresh=rel_l1_thresh,
        num_steps=num_steps,
    )

    transformer = pipe.transformer
    original_forward = transformer.forward

    def cached_forward(self, hidden_states, timestep, class_labels=None, **kwargs):
        compute_full = state.should_compute(hidden_states)

        ifnot compute_full and state.previous_sample isnotNone:
            if state.previous_residual isnotNoneand state.previous_residual.shape == hidden_states.shape:
                sample = hidden_states + state.previous_residual.to(
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
            else:
                sample = state.previous_sample.to(
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
            state.cached_steps += 1
            state.next_step()
            return Transformer2DModelOutput(sample=sample)

        out = original_forward(
            hidden_states=hidden_states,
            timestep=timestep,
            class_labels=class_labels,
            **kwargs,
        )

        state.previous_sample = out.sample.detach()
        if out.sample.shape == hidden_states.shape:
            state.previous_residual = out.sample.detach() - hidden_states.detach()
        else:
            state.previous_residual = None
            print("Warning: output sample shape differs from input hidden_states shape; cannot compute residual for caching.")
        state.full_compute_steps += 1
        state.next_step()
        return out

    transformer.forward = types.MethodType(cached_forward, transformer)
    return state


def run_generation(pipe, device):
    generator = torch.Generator(device=device).manual_seed(42)
    return pipe(
        class_labels=[CLASS_ID],
        num_inference_steps=NUM_STEPS,
        generator=generator,
    ).images[0]


def main():
    device = "cuda"if torch.cuda.is_available() else"cpu"
    print(f"device = {device}")

    pipe = load_dit_pipeline(device)

    clear_cuda_cache()
    t0 = time.time()
    image = run_generation(pipe, device)
    baseline_time = time.time() - t0
    baseline_image = image
    baseline_image.save("dit_baseline.png")
    print("Baseline image:")
    display(baseline_image)
    print(f"[Baseline] time = {baseline_time:.2f}s")

    del pipe
    clear_cuda_cache()

    pipe = load_dit_pipeline(device)
    state = enable_simple_teacache_for_dit(
        pipe,
        rel_l1_thresh=0.20,
        num_steps=NUM_STEPS,
    )

    clear_cuda_cache()
    t0 = time.time()
    image = run_generation(pipe, device)
    cached_time = time.time() - t0
    cached_image = image
    cached_image.save("dit_simple_teacache.png")
    print("SimpleTeaCache image:")
    display(cached_image)

    print(f"[SimpleTeaCache] time = {cached_time:.2f}s")
    print(f"full_compute_steps = {state.full_compute_steps}")
    print(f"cached_steps       = {state.cached_steps}")
    print(f"speedup            = {baseline_time / cached_time:.2f}x")


if __name__ == "__main__":
    main()

输出:

代码语言:javascript
复制
[Baseline] time = 1.42s
[SimpleTeaCache] time = 0.57s
full_compute_steps = 9
cached_steps       = 16
speedup            = 2.48x

对比

threshold

time(s)

speedup

full

cached

0.1

0.59

2.80x

10

15

0.2

0.68

2.43x

9

16

0.3

0.42

3.89x

7

18

0.4

0.40

4.12x

6

19

0.5

0.42

3.98x

6

19

0.6

0.39

4.19x

6

19

0.7

0.38

4.38x

5

20

0.8

0.35

4.77x

5

20

0.9

0.33

4.94x

5

20

1.0

0.36

4.63x

5

20

如何调参?

先用:

代码语言:javascript
复制
rel_l1_thresh = 0.20

然后逐步尝试:

代码语言:javascript
复制
0.10:更稳,速度提升较小
0.20:均衡起点
0.30:更快,但可能有质量波动
0.50:偏激进,容易出画质问题

如果发现生成图像细节变差、主体变形、纹理糊,先把阈值降回:

代码语言:javascript
复制
rel_l1_thresh = 0.10

7. 工程踩坑清单

坑 1:把 TeaCache 当成 LLM KV Cache

LLM KV Cache 缓存的是 token 历史的 key/value。

TeaCache 缓存的是 diffusion denoising step 之间的中间输出 / residual。

二者不是一个东西。


坑 2:rel_l1_thresh 太大

表现:

代码语言:javascript
复制
图像细节糊
视频运动不稳定
人物脸部漂移
文字生成质量下降
编辑任务不稳定

解决:

代码语言:javascript
复制
cache_config = {"rel_l1_thresh": 0.1}

vLLM-Omni 文档在质量下降场景下也建议降低 threshold,使缓存更保守


坑 3:步数太少时加速不明显

TeaCache 需要足够多的 denoising steps 才有跳过空间。 vLLM-Omni 文档也提到,非常短的推理过程,例如小于 20 steps,缓存开销可能抵消收益;如果加速低于预期,建议使用足够多的 inference steps,例如 35+


坑 4:coefficients 不能乱迁移

官方 TeaCache 里有模型相关的多项式 coefficients。不同模型的 timestep embedding、Transformer 结构、scheduler 都可能不同。

TeaCache 官方仓库也提示:结构相近的模型可以尝试迁移 coefficients,否则需要参考已有适配或重新适配。


坑 5:服务化时 cache state 必须按请求隔离

不要在多用户服务里把这些状态做成全局变量:

代码语言:javascript
复制
previous_residual
previous_input
accumulated_rel_l1
step_idx

否则请求 A 的 cache 可能污染请求 B。

正确做法:

代码语言:javascript
复制
每个 request 独立 TeaCacheState
请求结束后 reset
CFG cond/uncond 分支分别维护 cache
多线程/异步场景避免共享状态

8. 总结

TeaCache 的本质可以压缩成一句话:

❝利用 timestep embedding 感知相邻 denoising step 的输出变化;变化小时复用缓存,变化大时完整计算。

它最适合 DiT 类图像/视频/音频扩散模型,尤其是 denoising steps 较多、Transformer 计算占主要瓶颈的场景。生产里建议从 rel_l1_thresh=0.2 开始,逐步调大,同时用固定 prompt、seed、分辨率和 steps 对比不开缓存、保守缓存、激进缓存三组结果。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2026-05-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Michael阿明 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章目录
  • 1. 为什么 TeaCache 能加速?
  • 2. TeaCache 判断“能不能跳过”的依据
  • 3. 普通推理 vs TeaCache 推理
  • 4. 最关键参数:rel_l1_thresh
  • 5. 适用场景
  • 6. vLLM-Omni 正式用法
    • 安装环境
    • 代码
    • 如何调参?
  • 7. 工程踩坑清单
    • 坑 1:把 TeaCache 当成 LLM KV Cache
    • 坑 2:rel_l1_thresh 太大
    • 坑 3:步数太少时加速不明显
    • 坑 4:coefficients 不能乱迁移
    • 坑 5:服务化时 cache state 必须按请求隔离
  • 8. 总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档