许可证:apache-2.0
标签:
- pytorch
- diffusers
- 无条件图像生成
潜在扩散模型(LDM)
论文:高分辨率图像合成的潜在扩散模型
摘要:
通过将图像生成过程分解为去噪自编码器的顺序应用,扩散模型(DMs)在图像数据及其他领域实现了最先进的合成效果。此外,其公式化允许通过引导机制控制图像生成过程而无需重新训练。然而,由于这些模型通常直接在像素空间操作,强大DMs的优化往往消耗数百个GPU天,且由于顺序评估导致推理成本高昂。为了在有限计算资源下训练DMs同时保持其质量和灵活性,我们在强大预训练自编码器的潜在空间中应用它们。与之前工作不同,在此类表示上训练扩散模型首次实现了复杂度降低与细节保留之间的近乎最优平衡,极大提升了视觉保真度。通过在模型架构中引入交叉注意力层,我们将扩散模型转变为强大且灵活的生成器,支持文本或边界框等通用条件输入,并以卷积方式实现高分辨率合成。我们的潜在扩散模型(LDMs)在图像修复任务上达到了新的技术巅峰,并在无条件图像生成、语义场景合成和超分辨率等多项任务中表现出极具竞争力的性能,同时相比基于像素的DMs显著降低了计算需求。
作者
Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer
使用方法
使用管道进行推理
!pip install diffusers
from diffusers import DiffusionPipeline
model_id = "CompVis/ldm-celebahq-256"
pipeline = DiffusionPipeline.from_pretrained(model_id)
image = pipeline(num_inference_steps=200)["sample"]
image[0].save("ldm_generated_image.png")
使用展开循环进行推理
!pip install diffusers
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import PIL.Image
import numpy as np
import tqdm
seed = 3
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("CompVis/ldm-celebahq-256", subfolder="scheduler")
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)
vqvae.to(torch_device)
generator = torch.manual_seed(seed)
noise = torch.randn(
(1, unet.in_channels, unet.sample_size, unet.sample_size),
generator=generator,
).to(torch_device)
scheduler.set_timesteps(num_inference_steps=200)
image = noise
for t in tqdm.tqdm(scheduler.timesteps):
with torch.no_grad():
residual = unet(image, t)["sample"]
prev_image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
image = prev_image
with torch.no_grad():
image = vqvae.decode(image)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save(f"generated_image_{seed}.png")
示例



