library_name: birefnet
tags:
- 背景去除
- 掩膜生成
- 二分图像分割
- 伪装目标检测
- 显著目标检测
- pytorch_model_hub_mixin
- model_hub_mixin
- transformers
- transformers.js
repo_url: https://github.com/ZhengPeng7/BiRefNet
pipeline_tag: 图像分割
license: mit
双边参考的高分辨率二分图像分割
1 南开大学 2 西北工业大学 3 国防科技大学 4 阿尔托大学 5 上海人工智能实验室 6 特伦托大学
DIS示例1 |
DIS示例2 |
 |
 |
本仓库是论文"双边参考的高分辨率二分图像分割"(CAAI AIR 2024)的官方实现。
访问我们的GitHub仓库:https://github.com/ZhengPeng7/BiRefNet 获取更多细节——代码、文档和模型库!
使用方法
0. 安装依赖包:
pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
1. 加载BiRefNet:
使用HuggingFace的代码+权重
仅使用HuggingFace上的权重——优点:无需手动下载BiRefNet代码;缺点:HuggingFace上的代码可能不是最新版本(我会尽量保持最新)。
from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
使用GitHub代码+HuggingFace权重
仅使用HuggingFace上的权重——优点:代码始终最新;缺点:需要从我的GitHub克隆BiRefNet仓库。
# 下载代码
git clone https://github.com/ZhengPeng7/BiRefNet.git
cd BiRefNet
from models.birefnet import BiRefNet
birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
使用GitHub代码+本地权重
同时使用本地代码和权重。
import torch
from utils import check_state_dict
birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)
使用加载的BiRefNet进行推理
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from models.birefnet import BiRefNet
birefnet = ...
torch.set_float32_matmul_precision(['high', 'highest'][0])
birefnet.to('cuda')
birefnet.eval()
birefnet.half()
def extract_object(birefnet, imagepath):
image_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(imagepath)
input_images = transform_image(image).unsqueeze(0).to('cuda').half()
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image.size)
image.putalpha(mask)
return image, mask
plt.axis("off")
plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
plt.show()
2. 本地使用推理端点:
您可能需要点击部署并自行设置端点,这可能会产生一些费用。
import requests
import base64
from io import BytesIO
from PIL import Image
YOUR_HF_TOKEN = 'xxx'
API_URL = "xxx"
headers = {
"Authorization": "Bearer {}".format(YOUR_HF_TOKEN)
}
def base64_to_bytes(base64_string):
# 如果存在数据URI前缀则移除
if "data:image" in base64_string:
base64_string = base64_string.split(",")[1]
# 将Base64字符串解码为字节
image_bytes = base64.b64decode(base64_string)
return image_bytes
def bytes_to_base64(image_bytes):
# 创建BytesIO对象处理图像数据
image_stream = BytesIO(image_bytes)
# 使用Pillow(PIL)打开图像
image = Image.open(image_stream)
return image
def query(payload):
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
output = query({
"inputs": "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg",
"parameters": {}
})
output_image = bytes_to_base64(base64_to_bytes(output))
output_image
这个用于标准二分图像分割(DIS)的BiRefNet是在DIS-TR上训练并在DIS-TEs和DIS-VD上验证的。
本仓库包含论文"双边参考的高分辨率二分图像分割"(CAAI AIR 2024)的官方模型权重。
本仓库包含我们论文中提出的BiRefNet的权重,该模型在三个任务(DIS、HRSOD和COD)上取得了最先进的性能。
前往我的GitHub页面获取BiRefNet代码和最新更新:https://github.com/ZhengPeng7/BiRefNet :)
尝试我们的在线推理演示:
- 在Colab上进行图像推理:

- 在Hugging Face上使用GUI进行在线推理,可调整分辨率:

- 对给定权重进行推理和评估:

致谢:
- 非常感谢@Freepik为训练更高分辨率BiRefNet模型及更多探索提供的GPU资源支持。
- 非常感谢@fal为训练更好的通用BiRefNet模型提供的GPU资源支持。
- 非常感谢@not-lain在HuggingFace上更好部署我们的BiRefNet模型方面的帮助。
引用
@article{zheng2024birefnet,
title={双边参考的高分辨率二分图像分割},
author={郑鹏 and 高德宏 and 范登平 and 刘力 and Jorma Laaksonen and 欧阳万里 and Nicu Sebe},
journal={CAAI人工智能研究},
volume = {3},
pages = {9150038},
year={2024}
}