许可证: mit
数据集:
- Artificio/WikiArt
管道标签: 无条件图像生成
模型描述
该模型是基于Wikiart数据集(https://huggingface.co/datasets/Artificio/WikiArt)从头训练的DiT(扩散变换器)模型,旨在根据艺术流派和艺术风格生成艺术作品图像。
模型架构
该模型大体上遵循了论文《Scalable Diffusion Models with Transformers》中描述的经典DiT架构,并进行了少量修改:
- 将ImageNet类别嵌入替换为Wikiart流派和风格嵌入;
- 使用后归一化(post-norm)而非前归一化(pre-norm);
- 省略了最终的线性层;
- 将sin-cos-2d位置嵌入替换为可学习的位置嵌入;
- 模型仅预测噪声,不学习sigma;
- 所有模型变体的patch_size均设置为2;
- 模型有不同的尺寸设置。
如需了解更多细节,请查看本仓库中的modeling_dit_wikiart.py文件。
模型包含三种变体:
- S(小型):num_blocks=8,hidden_size=384,num_heads=6,总参数量=20M;
- B(基础型):num_blocks=12,hidden_size=640,num_heads=10,总参数量=90M;
- L(大型):num_blocks=16,hidden_size=896,num_heads=14,总参数量=234M。
训练过程
- 数据集:所有模型变体均在经过水平翻转数据增强的103K Wikiart数据集上训练。
- 优化器:使用默认设置的AdamW。
- 学习率:在前1%的训练步数中线性预热,学习率最高达到3e-4,随后通过余弦衰减降至零。
- 训练周期和批次大小:
- S:96个周期,批次大小为176;
- B:120个周期,批次大小为192;
- L:144个周期,批次大小为192。
- 训练设备:
- S:单块RTX 4060ti 16G显卡,训练24小时;
- B:单块RTX 4060ti 16G显卡,训练90小时;
- L:单块RTX 4090D 24G显卡训练48小时,随后使用单块RTX 4060ti 16G显卡训练100小时。
- 损失曲线:所有变体在第一个周期内损失值从1.0000以上急剧下降至约0.2000,随后缓慢下降,最终在第20个周期达到loss=0.1600。DiT-S最终达到0.1590;DiT-B最终达到0.1525;DiT-L最终达到0.1510。训练过程稳定,未出现损失值突增。
性能与局限性
- 模型展示了初步理解流派和风格并生成视觉吸引力绘画的能力(乍看之下)。
- 局限性包括:
- 无法理解复杂结构(如人脸、建筑物等);
- 当生成数据集中罕见的流派或风格(如极简主义风格或浮世绘流派)时,偶尔会出现模式崩溃;
- 分辨率限制为256x256;
- 由于基于Wikiart数据集训练,无法生成超出该数据集范围的图像。
使用方法
使用模型前,请安装"huggingface_hub"库,并从"Files and versions"中下载modeling_dit_wikiart.py以获取模型定义。之后可通过以下代码使用模型:
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
该模型需搭配stabilityai/sd-vae-ft-ema使用。