解读:AI生成未来

文章链接:https://arxiv.org/pdf/2509.21318

第一印象:4 步模型中的高保真样本
亮点直击
提出了 SD3.5-Flash,一个高效的少步蒸馏框架,核心包括两项算法创新:
流匹配。扩散模型是一类生成模型,其学习一条从(高斯)噪声到数据的轨迹,并通过迭代地遵循该轨迹从采样的噪声生成媒体。这条从噪声到数据的轨迹通常在基于分数的生成框架中被建模为一个随机微分方程(SDE)的解,并且可以被重新表述为一个常微分方程(ODE),即概率流ODE(PF-ODE)。基于分数的生成框架中的扩散模型学习一个分数函数——即对数概率密度的梯度——通过训练一个神经网络来估计其在轨迹上不同噪声水平处的值。更新方向可以定义为:

其中 被称为 的分数函数,并由神经网络参数化为 ,在概率流ODE(PF-ODE,Karras等人,2022)中,。相比之下,流匹配(Lipman等人,2022;Esser等人,2024)模型定义了一类独立的生成方法,其直接学习基于ODE的映射,而不依赖于底层的SDE。这些模型参数化了一个速度场,该速度场沿着ODE定义的轨迹将样本从噪声传输到数据。流匹配的更新方向变为 ,其中速度 由网络参数化为 。在如 SD3.5 Medium(Stability AI, 2024)这样的整流 pipeline (Liu等人,2022)中,样本按照数据分布和标准正态分布 之间的直线路径添加噪声:。
分布匹配蒸馏。DMD(Yin等人,2024b)提出通过将学生分布 与教师分布 进行匹配,来将一个多步教师模型 蒸馏成一个蒸馏后的单步学生模型 。给定一个样本 ,其中 ,这种分布匹配被计算为库尔巴克-莱布勒散度(KL散度):

然而,由于概率密度通常是难以处理的,直接使用该散度作为损失函数是不可能的。由于只需要该损失的梯度,这可以通过代入分数函数 并计算损失梯度来规避:

为了获得这些分数,生成的样本 被重新加噪至时间步 ,即 。然后,分数根据预训练扩散模型的去噪信号计算,教师分数为 ,学生分数为 ,其中学生分数 来自学生模型 。由于少步模型仅在一部分时间步上工作,因此需要维护一个多步代理模型,该模型监控少步模型的分布并充当替代的学生分数估计器。为了稳定此流程,LDMD 辅以回归损失,该损失计算为从相同噪声开始,由学生和教师生成的图像之间的均方误差(MSE)。DMD2(Yin等人,2024a)提出使用有偏调度来更新学生代理 ,以提高稳定性,而无需引入此回归损失,并用对抗目标来补充 LDMD。
为了稳定地预训练我们的4步学生网络,使用轨迹引导目标 。对于教师模型轨迹上的时间步 ,我们对与学生轨迹重合的点 进行子采样(即对于4步模型,),并将轨迹引导目标计算为:

其中 对应于速度预测器教师模型,而 是正在训练的学生模型。
本文使用公式(3)中的 DMD 目标来微调我们预训练的学生模型,该目标通过代理()计算教师和学生分布之间 KL 散度的梯度。为了使代理和学生的分布对齐,从而通过在对生成的学生样本 进行微调来实现 LDMD 中学生分布的准确表示,我们对生成的学生样本 进行微调。具体来说,将终点估计值 加噪至 ,并计算流匹配损失为 ,其中 来自所添加的噪声。为了在时间步 ()训练学生模型 ,我们禁用梯度并使用学生模型本身生成直到 的结果。与 Yin 等人(2024a)不同,我们发现,对于时间步 ,直接在稍微嘈杂的样本 上开始训练,相比在样本 上训练,能提高性能。在训练稳定后,我们切换回在 上训练时间步 ,类似于 Yin 等人提出的“后向模拟”。
时间步共享。公式(3)中的 DMD 目标需要将样本从 加噪至 以分别计算真实分数和伪造分数 和 。在基于分数的模型中,这是通过向样本添加随机噪声来实现的,而这已经是去噪循环的一部分。然而,预训练的基于流的模型具有匹配的图像-噪声对,为达到时间步 而添加随机噪声可能会产生嘈杂的梯度更新。我们通过将 DMD 时间步与少步去噪调度中的时间步共享来简化训练目标并防止噪声添加。
本文评估 KL 散度梯度时,不是通过从轨迹端点(即公式(3)中的 到 )重新加噪,而是简单地使用学生轨迹上的部分去噪样本()进行速度估计。直观地说,我们计算假设的“伪” 的分数,该 被加噪至 ,而不是估计 本身(见下图3)。这减少了来自嘈杂时间步(在 时)的劣质 估计所导致的低质量梯度。因此,这迫使我们将分布匹配的时间步与学生轨迹的时间步 共享,而不是使用公式(3)中的随机 。虽然这确实导致时间步的变化减少(仅使用学生轨迹中的少数几个时间步),但我们发现它提高了图像构图和生成质量。

分时间步微调。时间步蒸馏常常会削弱文本提示词与生成输出之间的对应关系。为了抵消这一点,我们设计了分时间步微调,其灵感来源于先前利用扩散模型进行多任务学习的工作。首先将预训练模型复制到分支 和 中,并分别在不相交的时间步范围 和 上对它们进行训练,以增加有效模型容量。在微调期间,每个分支使用衰减率为 的指数移动平均来稳定权重并使其接近原始检查点。收敛后,我们通过权重插值融合分支,选择 3:7 的比例()以最大化 GenEval测量的文本提示词对齐度。我们仅在对我们的四步模型进行训练时执行分时间步微调,在那里我们观察到模型性能有显著提升。
本文使用一个对抗目标,其中代理学生 充当特征提取器以获取判别器特征。这使我们能够在流潜在空间上进行对抗训练,而不是在中的图像空间。为了使用 提取特征,我们将样本 加噪到时间步 上预定义的噪声水平,并从 的多个层提取中间输出作为特征图。时间步 在 区间内均匀分布,以捕获粗粒度特征()和细粒度特征()。在这些特征之上训练 MLP 判别器头 进行真/假预测,其中由教师模型生成的合成样本被用作“真实”数据。与 NitroFusion类似,通过重新初始化判别器头的权重来定期刷新它们,以减少过拟合。我们使用标准的非饱和 GAN 目标来训练判别器头和生成器 :

其中判别器头 (上图3)和特征提取器统称为 。
为了训练一个两步生成器,我们逐步将一个多步教师模型蒸馏到一个四步学生模型,并继续训练它以进行两步推理。我们首先使用来自多步教师模型的预训练权重来初始化我们的教师、学生和代理学生。接下来,我们执行两个训练阶段:(i)使用 预训练学生模型,其中模型被优化以在少量步骤内复制教师轨迹。(ii)在第二阶段,我们最小化教师和学生分布的 KL 散度 ,并辅以我们多头判别器的对抗目标。第一阶段的训练有助于对齐教师和学生的轨迹,并显著加快下一阶段的训练速度。第二阶段构建清晰的特征和详细的图像。我们使用训练好的四步模型作为预训练检查点,按照我们训练流程的第二阶段将其蒸馏到两步。在此过程中,我们还使用教师和学生模型样本特征 Gram 矩阵之间的 MSE 目标。
在 Stable Diffusion 3.5 pipeline 之上执行推理优化。该 pipeline 除了 MM-DiT 扩散模型和 VAE之外,还包括三个文本编码器(CLIP-L、CLIP-G和 T5-XXL)。其中,T5-XXL 是最大的组件,占用了峰值 VRAM 使用量和推理时间的大部分。完整的 16 位精度蒸馏模型需要 18 GiB 的 GPU 内存——这超出了大多数消费级显卡的能力范围。为了降低需求,我们将 MM-DiT 扩散模型量化为 8 位,并利用 SD3.5 中的编码器丢弃预训练来用空嵌入替换 T5-XXL。这将我们的内存需求降低到仅约 8 GiB。为了真正支持手机和平板电脑等边缘设备,我们使用 Apple Silicon 上的 CoreML 将我们的 8 位模型进一步量化为 6 位(下图2)。专门针对此量化,我们重写了 RMSNorm 等操作,以在 Apple Neural Engine 上更好地保持精度。在下表1中总结了我们的优化结果,并强调了在 iPhone(补充压缩包中的视频)和 iPad 等设备上低于 10 秒的延迟。我们在下图8中包含了关于内存性能权衡的更多细节。



数据集与训练。遵循先前的工作,本文使用合成样本来训练我们的模型,因为它们具有高提示连贯性和一致的质量。对于我们的训练数据,我们使用 SD3.5 Large (8B) 模型在 32 个时间步和 CFG 尺度为 4.0 的情况下生成合成样本。我们进行 2K 次迭代的预训练,然后分别使用 2.5B 的 SD3.5M 作为教师模型,对 4 步和 2 步模型各训练 1200 次迭代。2 步模型从 4 步中间检查点开始训练。
基线。为了进行比较,本文考察了以 SDXL作为教师网络训练的 DMD2、Hyper-SD、SDXL-Turbo、Nitrofusion和 SDXL-Lightning。DMD2 通过匹配教师和学生的分布与 KL 散度目标的梯度来蒸馏 SDXL。Hyper-SD 通过轨迹引导执行一致性蒸馏,并使用人类反馈学习来提高性能。SDXL-Turbo 在 Dino-V2的丰富语义空间中展示了对抗蒸馏,在整个训练过程中将潜在变量解码为图像。SDXL-Lightning 也使用对抗蒸馏,但通过在判别器中混合使用条件目标和无条件目标来放宽对学生的模式覆盖要求。Nitrofusion 通过多判别器设置和周期性判别器刷新来稳定对抗蒸馏,并在 SDXL-DMD2 和 SDXL-HyperSD 上进行训练。相较于 SDXL 和 SDv2.1,最近的模型如 SD3.5和 SANA通过采用整流流 pipeline 以实现更快收敛,提供了更好的生成质量和更高的提示遵循度。SWD通过训练一个尺度感知网络来蒸馏 SD3.5M,并使用分布匹配目标进行优化。SANA-Sprint使用连续时间一致性蒸馏将 SANA 蒸馏到 1、2 和 4 步模型。我们还包括与 TensorArt Studios发布的 SD3.5M-Turbo 的比较,它是基于 SD3.5M 的一个独立检查点。我们不与难以装入消费级硬件的大型模型(如 SD3.5 Large (8B) 和 Flux.1-dev(12B))进行比较。
下图 5 中包含了我们的模型(SD3.5-Flash 16-bit + T5)与其他少步生成流程(如 SANA-Sprint1.6B、NitroFusion、SDXL-DMD2 和 SDXL-Lightning)的定性比较,并在附录中提供了更多比较(包括 SWD)。来自 SDXL-DMD2、SDXL-Lightning和 NitroFusion的 4 步结果显示,在涉及人物互动的复杂提示中,提示对齐和构图效果较差。SDXL-Lightning(Lin等人,2024)生成的图像平滑但缺乏锐度且细节不足,有时会产生伪影(例如最后一行最后一列,沙发上的两只柯基犬)。SDXL-DMD2和 NitroFusion(从 SDXL-DMD2 蒸馏而来)生成的纹理更好,但在构图方面同样表现较差,并导致伪影(第二行,书上的猫和第一行,三只猫头鹰)。相比之下,我们的方法(4 步)始终生成高质量图像,并在生成保真度上显著优于其他 4 步流程。在 2 步流程中,我们与 SANA-Sprint 1.6B(Chen等人,2025)进行比较。SANA-Sprint生成了更多细节但风格不一致,有时在没有风格提示的情况下生成风格化图像(第一列和第三列)。SANA-Sprint在非特写环境中也会生成模糊的面部特征(见第四行)。我们的 2 步方法在生成保真度上优于 SANA-sprint,但落后于我们的 4 步模型(第三行缺失的书和第四行的伪影)。在下图 4 中还提供了我们的 4 步 16 位模型使用和不使用 T5 的示例。


基于图像质量和提示词对齐进行了一项用户研究,共有124名标注者参与评估使用4个不同种子生成的图像。为了生成样本,我们使用了一个包含507个提示词的多样化精选集,这些提示词由专家设计的提示词和Parti提示词的一个子集组成。对于每个生成的样本,3名用户对来自两种不同方法的两张图像进行投票,从视觉质量和图像-提示词相关性(提示词遵循度)两个方面对它们进行评分。从用户研究(下图6)中,SD3.5-Flash在图像质量上优于其他少步模型,甚至优于50步的教师模型。在提示词遵循度方面,所有方法之间的差异很小(< ±1.6%)。

本文还比较了选定的竞争对手以计算ELO分数(见上图2)。在所有计算场景中,我们的模型都位于ELO排行榜的顶端,展示了在各种计算预算下的高质量图像生成能力。
我们进行了广泛的定量验证(下表2),为来自COCO数据集的标题生成了30K个样本,其中我们使用了ImageReward、CLIPScore、FID和美学评分等指标来量化生成性能。ImageReward(IR)和美学评分(AeS)是人类偏好指标,经过训练以反映人类对图像质量的偏好。像CLIPScore和FID这样的指标分别用于量化文本对齐度和与真实图像的相似度。CLIPScore测量的是文本提示词与生成图像在CLIP ViT-B/32语义空间中的相似性。FID计算的是生成图像和真实图像(此处来自COCO)的分布在Inception-V3特征空间中的距离。我们还比较了GenEval得分,该指标在不同设置下生成特定对象的图像,并使用对象检测框架评估以识别文图对齐度。我们使用这些指标以及相应的延迟(即在RTX 4090 GPU上以16位浮点精度(BF16)生成一个样本所需的时间,除非另有说明)与所有基线和竞争对手进行比较。

从上表2中,我们发现我们的方法在文图生成方面与SDXL-DMD2和NitroFusion等近期工作相比具有竞争力,同时在GenEval、AeS和IR等指标上超过了教师模型SD3.5M。尽管是在相同的COCO-30K数据集上计算,我们注意到我们的FID较差,而其他指标具有竞争力的分数。我们将此归因于教师模型SDXL和SD3.5M本身的FID差异,并注意到基于SD3.5M训练的SD3.5M-Turbo和SWD平均具有更差的FID。
通过在我们的流程中逐个移除组件来蒸馏 SD3.5M(16 位 4 步)进行消融实验(下图 7),展示了它们对生成保真度的重要性。具体来说,我们蒸馏了以下模型:(i)无对抗目标:不使用 GAN 训练来指导生成,(ii)无预训练:不预训练学生生成器 ,(iii)无时间步共享:在 LDMD 中对 使用随机时间步 而不是学生轨迹上的时间步,以及(iv)无判别器刷新:不定期重新初始化判别器头以纠正过拟合。我们使用与学生模型相同的迭代次数训练消融模型。我们发现移除对抗目标会使训练不稳定,导致生成质量差。没有预训练时,颜色和构图受影响最大。没有时间步共享的训练也会导致纹理、颜色和构图差。最后,没有判别器刷新时,我们发现存在轻微的构图错误和图像过度平滑。

与所有蒸馏过程一样,在复杂生成任务中,我们以推理速度为代价,在质量和多样性的某些方面进行了权衡。我们发现,为了更快的推理和更低的内存而移除 T5,也会因条件上下文变差而难以构建复杂的构图(下图 4)。然而,这些限制并非我们方法所独有,而是用低步数模型近似扩散轨迹的自然结果。尽管如此,我们发现我们的 4 步模型相比教师模型实现了高达约 18 倍的加速,并在包含不同复杂度提示词的大规模用户研究中,其平均性能超过了教师模型。

[1] SD3.5-Flash: Distribution-Guided Distillation of Generative Flows
如果您觉得这篇文章对你有帮助或启发,请不吝点赞、在看、转发,让更多人受益。同时,欢迎给个星标⭐,以便第一时间收到我的最新推送。每一个互动都是对我最大的鼓励。让我们携手并进,共同探索未知,见证一个充满希望和伟大的未来!