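base_model:
- THUDM/CogView4-6B
datasets:
- sayapaul/OmniEdit-mini
library_name: diffusers
widget:
- text: >-
    Make it look like a thick impasto painting
  output:
    url: output1.png
- text: >-
    Change the scene to spring with trees in full bloom
  output:
    url: output2.png
- text: >-
    Transform the scene into a stormy outer space
  output:
    url: output3.png
tags:
- text-to-image
- diffusers-training
- diffusers
- template:sd-lora
- cogview4
- finetrainers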
<Gallery />
This is a Control LoRA for editing images with the THUDM/CogView4-6B model.
Code: https://github.com/a-r-r-o-w/finetrainers
[!IMPORTANT]
This is an experimental checkpoint, and its poor generalization is a known issue.
Inference code:
import torch
from diffusers import CogView4Pipeline
from diffusers.utils import load_image
from finetrainers.models.utils import _expand_linear_with_zeroed_weights
from finetrainers.patches import load_lora_weights
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat

dtype = torch.bfloat16
device = torch.device("cuda")
generator = torch.Generator().manual_seed(0)

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=dtype)

# Double the patch embedding's input width so it can take noise and control latents together.
in_channels = pipe.transformer.config.in_channels
patch_channels = pipe.transformer.patch_embed.proj.in_features
pipe.transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
    pipe.transformer.patch_embed.proj, new_in_features=2 * patch_channels
)

# Load the edit LoRA and apply it with an adapter weight of 0.9.
load_lora_weights(pipe, "finetrainers/CogView4-6B-Edit-LoRA-v0", "cogview4-lora")
pipe.set_adapters("cogview4-lora", 0.9)
pipe.to(device)

prompt = "Make the image look like an ancient Egyptian mural"
control_image = load_image("examples/training/control/cogview4/omni_edit/validation_dataset/0.png")
height, width = 1024, 1024

with torch.no_grad():
    # Initial noise latents for a single image.
    latents = pipe.prepare_latents(1, in_channels, height, width, dtype, device, generator)
    # Encode the image to edit into VAE latents, then shift and scale them.
    control_image = pipe.image_processor.preprocess(control_image, height=height, width=width)
    control_image = control_image.to(device=device, dtype=dtype)
    control_latents = pipe.vae.encode(control_image).latent_dist.sample(generator=generator)
    control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

# Concatenate the control latents to the transformer's hidden_states along dim 1 (channels).
with control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]):
    image = pipe(prompt, latents=latents, num_inference_steps=30, generator=generator).images[0]

image.save("output.png")
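
The inference code widens `patch_embed.proj` to twice its input features and concatenates the control latents along dimension 1. A plausible reading, going by the helper's name, is that the added input columns start at zero, so the widened layer initially reproduces the base model's output and only the LoRA introduces a dependence on the control image. The sketch below illustrates that idea with plain PyTorch on a tiny layer. `expand_linear_with_zeros` is a hypothetical stand-in written for this example; it is not the finetrainers implementation, and the layer sizes are arbitrary.

```python
import torch
from torch import nn


def expand_linear_with_zeros(linear: nn.Linear, new_in_features: int) -> nn.Linear:
    """Return a wider copy of `linear` whose extra input columns are zero.

    Hypothetical helper for illustration only; not the finetrainers code.
    """
    if new_in_features < linear.in_features:
        raise ValueError("new_in_features must be >= linear.in_features")
    expanded = nn.Linear(
        new_in_features,
        linear.out_features,
        bias=linear.bias is not None,
        device=linear.weight.device,
        dtype=linear.weight.dtype,
    )
    with torch.no_grad():
        # Zero everything, then copy the original weights into the leading columns.
        expanded.weight.zero_()
        expanded.weight[:, : linear.in_features].copy_(linear.weight)
        if linear.bias is not None:
            expanded.bias.copy_(linear.bias)
    return expanded


torch.manual_seed(0)
base = nn.Linear(4, 3)
wide = expand_linear_with_zeros(base, new_in_features=8)

x = torch.randn(2, 4)        # stands in for the noisy-latent features
control = torch.randn(2, 4)  # stands in for the control-latent features

# The zeroed columns ignore `control`, so the output matches the base layer.
same_output = torch.allclose(wide(torch.cat([x, control], dim=-1)), base(x), atol=1e-6)
```

Under this assumption, running the widened model without the LoRA would behave like the original CogView4 transformer regardless of the control image; the adapter loaded by `load_lora_weights` is what makes the control input matter.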