模型简介
模型特点
模型能力
使用案例
license: apache-2.0 pipeline_tag: text-to-image
Simple Diffusion XS
极致小巧,超凡品质
训练状态(已暂停):第16个训练周期
使用示例
import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from tqdm.auto import tqdm
import os
def encode_prompt(prompt, negative_prompt, device, dtype):
if negative_prompt is None:
negative_prompt = ""
with torch.no_grad():
positive_inputs = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=512,
truncation=True,
).to(device)
positive_embeddings = text_model.encode_texts(
positive_inputs.input_ids, positive_inputs.attention_mask
)
if positive_embeddings.ndim == 2:
positive_embeddings = positive_embeddings.unsqueeze(1)
positive_embeddings = positive_embeddings.to(device, dtype=dtype)
negative_inputs = tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
max_length=512,
truncation=True,
).to(device)
negative_embeddings = text_model.encode_texts(negative_inputs.input_ids, negative_inputs.attention_mask)
if negative_embeddings.ndim == 2:
negative_embeddings = negative_embeddings.unsqueeze(1)
negative_embeddings = negative_embeddings.to(device, dtype=dtype)
return torch.cat([negative_embeddings, positive_embeddings], dim=0)
def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
with torch.no_grad():
device, dtype = embeddings.device, embeddings.dtype
half = embeddings.shape[0] // 2
latent_shape = (half, 16, height // 8, width // 8)
latents = torch.randn(latent_shape, device=device, dtype=dtype)
embeddings = embeddings.repeat_interleave(half, dim=0)
scheduler.set_timesteps(num_inference_steps)
for t in tqdm(scheduler.timesteps, desc="生成中"):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
noise_pred = unet(latent_model_input, t, embeddings).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(latents, vae, output_type="pil"):
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
with torch.no_grad():
images = vae.decode(latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
images = (images * 255).round().astype("uint8")
images = [Image.fromarray(image) for image in images]
return images
# 使用示例:
if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
prompt = "猫"
negative_prompt = "低质量"
tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
text_model = AutoModel.from_pretrained(
"visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
).to(device, dtype=dtype).eval()
embeddings = encode_prompt(prompt, negative_prompt, device, dtype)
pipeid = "AiArtLab/sdxs"
variant = "fp16"
unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")
height, width = 576, 576
num_inference_steps = 40
output_folder, project_name = "samples", "sdxs"
latents = generate_latents(
embeddings=embeddings,
height=height,
width=width,
num_inference_steps = num_inference_steps
)
images = decode_latents(latents, vae)
os.makedirs(output_folder, exist_ok=True)
for idx, image in enumerate(images):
image.save(f"{output_folder}/{project_name}_{idx}.jpg")
print("图像已生成并保存至:", output_folder)
项目介绍
快速、轻量级、多语言扩散模型,为所有人打造
我们是AiArtLab,一个预算有限的小型爱好者团队。我们的目标是创建一个能在消费级显卡上完成完整训练周期(非LoRA)的紧凑快速模型。选择U-Net架构是因为它能高效处理小数据集,即使在16GB GPU(如RTX 4080)上也能快速训练。我们的预算只有几千美元,远低于SDXL等竞争对手(数千万美元),因此决定打造一个类似SD1.5但面向2025年的小巧高效模型。
编码器架构(文本与图像)
我们测试了多种编码器,发现像LLaMA或T5 XXL这样的大型模型对高质量生成并非必需。但我们需要能理解查询上下文的编码器,专注于"提示理解"而非"提示跟随"。最终选择了支持80种语言、以句子为单位处理的多语言编码器Mexma-SigLIP。Mexma支持512个token,会生成减缓训练速度的大矩阵,因此我们通过池化层将512x1152矩阵简化为1x1152向量,并经过线性模型/文本投影器实现与SigLIP嵌入的兼容。这种设计使文本嵌入能与图像同步,有望实现统一的多模态模型,支持在查询中混合图像嵌入与文本描述。该模型还可仅用图像进行无文本训练,这对视频生成(标注困难)特别有利——通过输入带衰减系数的前一帧嵌入,可实现更连贯的视频生成。未来我们计划将该模型扩展至3D/视频生成领域。
U-Net架构
我们采用平滑通道金字塔结构:[384, 576, 768, 960](每块两层)和[4, 6, 8, 10]个含24个注意力头(1152/48)的transformer。这个约20亿参数的架构在RTX 4080上实现了最佳训练速度。我们认为其更大的"深度"能使质量媲美SDXL,尽管"尺寸"更小。通过添加1152层可扩展至40亿参数,实现与嵌入尺寸的完美对称——这种优雅设计可能达到"Flux/MJ级别"的质量。
VAE架构
我们选择了非常规的8x16通道AuraDiffusion VAE,它能保留细节、文字和解剖结构,没有SD3/Flux特有的"雾化"效果。采用带FFN卷积的快速版本后,发现精细图案会出现轻微纹理损伤,这可能影响基准测试评分。不过ESRGAN等超分工具能修复这些伪影。总体而言,我们认为这款VAE被严重低估了。
训练过程
优化器
测试了AdamW、Laion、Optimi-AdamW、Adafactor和AdamW-8bit后,最终选择AdamW-8bit。虽然Optimi-AdamW的梯度衰减曲线最平滑,但AdamW-8bit的较小体积允许更大批次,在低成本GPU(我们使用4xA6000和5xL40s)上实现了最大训练速度。
学习率
调整衰减/预热曲线有效果但不显著。实验表明Adam允许较宽的学习率范围,我们从1e-4开始,逐步降至1e-6。这说明正确的模型架构比调参重要得多。
数据集
使用约100万张图像训练:在256分辨率ImageNet上训练60个周期(因低质量标注浪费了时间),在CaptionEmporium/midjourney-niji-1m-llavanext上训练8个周期,外加576分辨率的真实照片和动漫/艺术作品。采用人工提示、Caption Emporium提供的提示、SmilingWolf的WD-Tagger和Moondream2进行标注,通过变化提示长度和组合确保模型理解不同提示风格。极小数据集导致模型会遗漏许多实体(如"自行车上的鹅"这类未见概念)。数据集中包含大量二次元风格图像,因为我们更关注模型学习人体解剖的能力而非"马上宇航员"这类技能。虽然多数描述是英文,但测试表明模型具备多语言能力。
局限性
- 极小数据集导致概念覆盖有限
- 图生图功能需进一步训练(我们将SigLIP占比降至5%以专注文生图)
致谢
- Stan - 主要投资人。感谢在众人认为我们疯狂时仍给予信任。
- Captainsaturnus - 物资支持
- Lovescape与Whargarbl - 精神支持
- CaptionEmporium - 数据集提供
"我们相信未来属于高效紧凑的模型。感谢所有捐助,期待继续支持。"
训练预算
捐助支持
如有意向提供GPU或资金支持训练,请联系我们
狗狗币:DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83 比特币:3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN
联系方式









