🚀 双边参考高分辨率二分图像分割(BiRefNet)
BiRefNet是用于高分辨率二分图像分割的模型,在图像分割领域有重要应用,能在背景去除、掩膜生成等多个任务中达到SOTA性能。
🚀 快速开始
本仓库是论文 "双边参考高分辨率二分图像分割" (CAAI AIR 2024)的官方实现。访问我们的GitHub仓库 https://github.com/ZhengPeng7/BiRefNet 以获取更多详细信息,包括 代码、文档 和 模型库!
✨ 主要特性
- 多任务SOTA性能:该模型在三个任务(二分图像分割DIS、高分辨率显著目标检测HRSOD和伪装目标检测COD)中取得了SOTA性能。
- 多种使用方式:可以结合HuggingFace的权重和代码使用,也可以结合GitHub的代码和HuggingFace的权重使用,还能使用本地的代码和权重。
- 在线演示:提供了在线单图像推理、带GUI的在线推理以及给定权重的推理和评估等在线演示。
📦 安装指南
0. 安装依赖包
pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
💻 使用示例
基础用法
使用HuggingFace的代码和权重
仅使用HuggingFace上的权重,优点是无需手动下载BiRefNet代码;缺点是HuggingFace上的代码可能不是最新版本(我会尽量保持其为最新版本)。
from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
使用GitHub的代码和HuggingFace的权重
仅使用HuggingFace上的权重,优点是代码始终是最新的;缺点是需要从我的GitHub克隆BiRefNet仓库。
# 下载代码
git clone https://github.com/ZhengPeng7/BiRefNet.git
cd BiRefNet
from models.birefnet import BiRefNet
birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
高级用法
使用本地的代码和权重
同时使用本地的权重和代码。
import torch
from utils import check_state_dict
birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)
使用加载好的BiRefNet进行推理
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from models.birefnet import BiRefNet
birefnet = ...
torch.set_float32_matmul_precision(['high', 'highest'][0])
birefnet.to('cuda')
birefnet.eval()
def extract_object(birefnet, imagepath):
image_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(imagepath)
input_images = transform_image(image).unsqueeze(0).to('cuda')
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image.size)
image.putalpha(mask)
return image, mask
plt.axis("off")
plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
plt.show()
用于标准二分图像分割(DIS)的BiRefNet在 DIS - TR 上进行训练,并在 DIS - TEs和DIS - VD 上进行验证。
📚 详细文档
本仓库包含了我们论文中提出的BiRefNet的权重。
在线演示
尝试我们的在线推理演示:
- Colab上的在线单图像推理:

- Hugging Face上带GUI的在线推理,分辨率可调节:

- 对给定权重进行推理和评估:

🔧 技术细节
该模型是基于双边参考机制设计的,用于高分辨率二分图像分割。其训练数据为 DIS - TR,验证数据为 DIS - TEs和DIS - VD。
📄 许可证
本项目采用MIT许可证。
致谢
- 非常感谢 @fal 为训练更好的BiRefNet模型在GPU资源上提供的慷慨支持。
- 非常感谢 @not - lain 在HuggingFace上更好地部署我们的BiRefNet模型方面提供的帮助。
引用
@article{BiRefNet,
title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
journal={CAAI Artificial Intelligence Research},
year={2024}
}