🚀 DiT-Wikiart-Small 模型
本模型是一个基于扩散变压器(Diffusion Transformer)架构的模型,专门用于无条件图像生成。它在Wikiart数据集上从头开始训练,能够根据艺术流派和风格生成艺术图像。
🚀 快速开始
要使用此模型,你需要安装 huggingface_hub
库,并从“文件和版本”中下载 modeling_dit_wikiart.py
用于模型定义。之后,你可以使用以下代码来使用该模型:
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Small")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
该模型与 stabilityai/sd-vae-ft-ema
搭配使用。
✨ 主要特性
- 基于扩散变压器(DiT)架构,在Wikiart数据集上从头开始训练。
- 能够根据给定的艺术流派和风格生成艺术图像。
- 模型有三种不同大小的变体,以满足不同的需求。
📦 安装指南
要使用此模型,你需要安装 huggingface_hub
库,并从“文件和版本”中下载 modeling_dit_wikiart.py
用于模型定义。
- 库链接:https://hf-mirror.com/kaupane/DiT-Wikiart-Small
💻 使用示例
基础用法
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Small")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
📚 详细文档
模型描述
本模型是一个在Wikiart数据集(https://huggingface.co/datasets/Artificio/WikiArt )上从头开始训练的DiT(扩散变压器)模型。它旨在根据艺术流派和风格生成艺术图像。
模型架构
该模型在很大程度上借鉴了论文 Scalable Diffusion Models with Transformers 中描述的经典DiT架构,并进行了一些细微的修改:
- 用Wikiart的流派和风格嵌入替换了ImageNet的类别嵌入;
- 使用后归一化(post-norm)代替前归一化(pre-norm);
- 省略了最后的线性层;
- 用学习到的位置嵌入替换了正弦 - 余弦二维位置嵌入;
- 模型仅预测噪声,不学习sigma;
- 所有模型变体的
patch_size
都设置为2;
- 模型有不同的大小设置。
如果你感兴趣,可以查看此仓库中的 modeling_dit_wikiart.py
以获取更多详细信息。
模型有三种变体:
- S:小型,
num_blocks=8
,hidden_size=384
,num_heads=6
,总参数为20M;
- B:基础型,
num_blocks=12
,hidden_size=640
,num_heads=10
,总参数为90M;
- L:大型,
num_blocks=16
,hidden_size=896
,num_heads=14
,总参数为234M。
训练过程
- 数据集:所有模型变体都在103K的Wikiart数据集上进行训练,并通过水平翻转进行数据增强。
- 优化器:使用默认设置的AdamW优化器。
- 学习率:在前1%的步骤中进行线性热身,学习率达到最大值3e-4,然后在后续步骤中进行余弦衰减至零。
- 训练轮数和批次大小:
- S:96轮,批次大小为176;
- B:120轮,批次大小为192;
- L:144轮,批次大小为192。
- 设备:
- S:单张RTX 4060ti 16G,训练24小时;
- B:单张RTX 4060ti 16G,训练90小时;
- L:先使用单张RTX 4090D 24G训练48小时,再使用单张RTX 4060ti 16G训练100小时。
- 损失曲线:所有变体在第一个训练轮次中损失从1.0000以上急剧下降到0.2000左右,随后下降速度变慢,在第20个训练轮次时最终达到损失值0.1600。DiT-S最终达到0.1590;DiT-B最终达到0.1525;DiT-L最终达到0.1510。
性能和局限性
- 模型展示了理解流派和风格并生成视觉上吸引人的绘画的基本能力(乍一看)。
- 局限性包括:
- 无法理解复杂的结构,如人脸、建筑物等。
- 在要求生成数据集中罕见的流派或风格时,偶尔会出现模式崩溃。例如极简主义风格和浮世绘流派。
- 分辨率限制为256x256。
- 由于在Wikiart数据集上训练,因此无法生成超出该数据集范围的图像。
📄 许可证
本模型采用MIT许可证。
📋 模型信息