language:
- en
tags:
- pytorch_model_hub_mixin
- animation
- video-frame-interpolation
- uncertainty-estimation
license: mit
pipeline_tag: image-to-image
🤖 Multi-Input ResShift Diffusion for Video Frame Interpolation
⚙️ Setup
Start by downloading the source code directly from GitHub:
git clone https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI.git
cd Multi-Input-Resshift-Diffusion-VFI
Create a conda environment and install all dependencies:
conda create -n multi-input-resshift python=3.12
conda activate multi-input-resshift
pip install -r requirements.txt
Note: Make sure your system is compatible with CUDA 12.4. If it is not, install the CuPy build that matches your installed CUDA version.
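Picking the right CuPy package by hand is easy to get wrong, so the following is a minimal sketch that suggests one. It assumes CuPy's prebuilt wheels are split by CUDA major version (cupy-cuda11x for CUDA 11.2 to 11.8, cupy-cuda12x for CUDA 12.x), and it uses torch.version.cuda, which reports the CUDA version PyTorch was built against rather than the system driver, as a stand-in for "your current CUDA version". The helpers cupy_package_for_cuda and detected_cuda_version are hypothetical names; the CuPy installation guide remains the authoritative source.

```python
import re
from typing import Optional


def cupy_package_for_cuda(cuda_version: str) -> str:
    """Suggest a CuPy wheel name for a CUDA version string such as "12.4".

    Assumption: cupy-cuda11x covers CUDA 11.2-11.8 and cupy-cuda12x covers
    CUDA 12.x. Any other version is rejected; consult the CuPy install guide.
    """
    match = re.fullmatch(r"(\d+)\.(\d+)(?:\.\d+)?", cuda_version.strip())
    if match is None:
        raise ValueError(f"Unrecognized CUDA version string: {cuda_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    if major == 12:
        return "cupy-cuda12x"
    if major == 11 and minor >= 2:
        return "cupy-cuda11x"
    raise ValueError(f"No CuPy wheel mapping known for CUDA {cuda_version}")


def detected_cuda_version() -> Optional[str]:
    """CUDA version PyTorch was built against, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda  # None for CPU-only PyTorch builds


if __name__ == "__main__":
    version = detected_cuda_version()
    if version is None:
        print("PyTorch with CUDA support was not found.")
    else:
        try:
            print(f"pip install {cupy_package_for_cuda(version)}")
        except ValueError as err:
            print(err)
```

Running the script inside the activated conda environment prints a suggested pip command, for example "pip install cupy-cuda12x" for a PyTorch build against CUDA 12.4.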
🚀 Inference Example
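import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

from utils.utils import denorm
from model.hub import MultiInputResShiftHub

# Download the pretrained weights from the Hugging Face Hub and move the model to the GPU
model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI").cuda()
model.eval()

# The two input frames; the model synthesizes the frame between them
# (os.path.join keeps the paths portable across Windows, Linux and macOS)
img0_path = os.path.join("_data", "example_images", "frame1.png")
img2_path = os.path.join("_data", "example_images", "frame3.png")

# Resize to 256 x 448 (height x width) and map pixel values from [0, 1] to [-1, 1]
mean = std = [0.5] * 3
transforms = Compose([
    Resize((256, 448)),
    ToTensor(),
    Normalize(mean=mean, std=std),
])

# Add a batch dimension: each tensor has shape (1, 3, 256, 448)
img0 = transforms(Image.open(img0_path).convert("RGB")).unsqueeze(0).cuda()
img2 = transforms(Image.open(img2_path).convert("RGB")).unsqueeze(0).cuda()

# tau = 0.5 requests the frame midway between img0 and img2
tau = 0.5
with torch.no_grad():  # inference only; also lets .numpy() run on the output
    img1 = model.reverse_process([img0, img2], tau)

def to_image(x):
    # Undo the normalization and convert (1, C, H, W) to an (H, W, C) array for matplotlib
    return denorm(x, mean=mean, std=std).squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()

panels = [(img0, "Input: frame1.png"), (img1, f"Interpolated (tau = {tau})"), (img2, "Input: frame3.png")]
plt.figure(figsize=(15, 5))
for i, (img, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, i)
    plt.imshow(to_image(img))
    plt.title(title)
    plt.axis("off")
plt.show()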
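The example above generates a single midpoint frame. To raise the frame rate by a larger factor, one option is to call reverse_process several times with evenly spaced tau values. The sketch below assumes, without confirmation from this card, that reverse_process accepts any tau strictly between 0 and 1 (not just 0.5) and takes a list of two frames as in the example. The names FrameInterpolator, uniform_taus and interpolate_sequence are hypothetical helpers, not part of the repository.

```python
from typing import Any, List, Protocol, Sequence


class FrameInterpolator(Protocol):
    """Anything exposing reverse_process(frames, tau), like the model above."""

    def reverse_process(self, frames: Sequence[Any], tau: float) -> Any: ...


def uniform_taus(num_intermediate: int) -> List[float]:
    """Evenly spaced temporal positions strictly between two frames."""
    if num_intermediate < 1:
        raise ValueError("num_intermediate must be at least 1")
    step = 1.0 / (num_intermediate + 1)
    return [step * k for k in range(1, num_intermediate + 1)]


def interpolate_sequence(
    model: FrameInterpolator, frames: Sequence[Any], num_intermediate: int
) -> List[Any]:
    """Insert num_intermediate generated frames between each consecutive pair."""
    taus = uniform_taus(num_intermediate)
    output: List[Any] = [frames[0]]
    for start, end in zip(frames[:-1], frames[1:]):
        for tau in taus:
            output.append(model.reverse_process([start, end], tau))
        output.append(end)  # keep the original frame after its generated ones
    return output


# Hypothetical usage with the tensors from the example above:
# with torch.no_grad():
#     frames = interpolate_sequence(model, [img0, img2], num_intermediate=3)
#     # -> [img0, frame at tau=0.25, frame at tau=0.5, frame at tau=0.75, img2]
```

Each generated frame is conditioned only on the two original frames around it rather than on previously generated ones, which avoids accumulating errors across recursive calls at the cost of one reverse_process call per output frame.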