
AI Toolkit 是由 Ostris 开发的一站式扩散模型训练套件,旨在支持消费级硬件上的最新图像和视频扩散模型训练。该项目既可作为图形界面(GUI)也可作为命令行工具(CLI)运行,设计理念是简单易用但功能全面。
核心价值:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.config_modules import ModelConfig
# 初始化模型
model_config = ModelConfig(
name_or_path="runwayml/stable-diffusion-v1-5",
is_v2=False,
dtype="fp16"
)
sd = StableDiffusion(
device="cuda",
model_config=model_config,
dtype="fp16"
)
# 加载模型
sd.load_model()
# 训练配置
train_config = {
"learning_rate": 1e-5,
"max_train_steps": 1000,
"train_batch_size": 4
}
# 开始训练
sd.train(train_config)from toolkit.config_modules import GenerateImageConfig
gen_config = GenerateImageConfig(
prompts=["A beautiful sunset over mountains"],
width=512,
height=512,
guidance_scale=7.5,
num_inference_steps=50
)
images = sd.generate_images(gen_config)
images[0].save("sunset.png")def load_model(self):
"""加载扩散模型组件"""
# 加载文本编码器
self.text_encoder = CLIPTextModel.from_pretrained(
self.model_config.name_or_path,
subfolder="text_encoder",
torch_dtype=self.torch_dtype
)
# 加载VAE
self.vae = AutoencoderKL.from_pretrained(
self.model_config.name_or_path,
subfolder="vae",
torch_dtype=self.torch_dtype
)
# 加载UNet
self.unet = UNet2DConditionModel.from_pretrained(
self.model_config.name_or_path,
subfolder="unet",
torch_dtype=self.torch_dtype
)def train_loop(self, dataloader, optimizer, lr_scheduler):
"""训练主循环"""
for epoch in range(self.train_config.num_epochs):
for batch in dataloader:
# 前向传播
latents = self.vae.encode(batch["pixel_values"]).latent_dist.sample()
noise = torch.randn_like(latents)
timesteps = torch.randint(
0, self.noise_scheduler.num_train_timesteps,
(latents.shape[0],), device=self.device
)
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
# 获取文本嵌入
text_embeddings = self.text_encoder(batch["input_ids"])[0]
# 预测噪声
noise_pred = self.unet(
noisy_latents, timesteps, text_embeddings
).sample
# 计算损失
loss = F.mse_loss(noise_pred, noise)
# 反向传播
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。