语言:
- 英文
许可证: MIT
库名称: transformers
标签:
- 视觉
- 图像转文本
- 图像描述生成
管道标签: 图像转文本
基础模型: Salesforce/blip2-opt-2.7b
VLRM
本仓库包含通过论文《VLRM:视觉语言模型作为图像描述生成的奖励模型》中介绍的强化学习方法微调的BLIP-2 OPT-2.7B模型权重。
经过强化学习调优的模型能够生成更长且更全面的描述,与原始模型相比无需额外计算开销。
更多细节请参见GitHub仓库(待完成)。
运行模型
选项1
从本仓库加载完整模型
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> '一位穿着格子衬衫的女子在佛罗里达州海滩的日落时分,与一只坐在地上的黄色拉布拉多犬握手'
选项2
由于微调层仅占整个模型的一小部分,您可以先加载原始模型,再加载强化学习调优的权重。
步骤1. 加载原始模型
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> '一位女子和一只狗坐在海滩上'
步骤2. 加载强化学习调优权重
可用检查点:
vlrm-blip2-opt-2.7b.pt
(论文中的VLRM)
vlrm-rs-blip2-opt-2.7b.pt
(论文中的VLRM-RS)
from huggingface_hub import hf_hub_download
finetuned_weights_state_dict = torch.load(hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt"))
model.load_state_dict(finetuned_weights_state_dict, strict=False)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> '一位穿着格子衬衫的女子在佛罗里达州海滩的日落时分,与一只坐在地上的黄色拉布拉多犬握手'