库名称: transformers
许可证: gemma
流水线标签: 图像文本到文本
额外授权标题: 在Hugging Face上访问PaliGemma
额外授权提示: 要访问Hugging Face上的PaliGemma,您需要审阅并同意Google的使用许可。请确保已登录Hugging Face账号并点击下方按钮,请求将立即处理。
额外授权按钮内容: 确认许可
PaliGemma模型卡
模型页面: PaliGemma
Transformers PaliGemma 3B权重,基于448*448输入图像在VQAv2数据集上微调。该模型仅提供float32、bfloat16和float16格式供研究使用。微调配置详见big_vision。
资源与技术文档:
使用条款: 条款
作者: Google
模型信息
模型概览
描述
PaliGemma是一款多功能轻量级视觉语言模型(VLM),灵感源自PaLI-3,基于开放组件如SigLIP视觉模型和Gemma语言模型。它接受图像和文本输入并生成文本输出,支持多语言。专为在图像/短视频描述、视觉问答、文本阅读、目标检测与分割等广泛视觉语言任务上实现顶尖微调性能而设计。
架构
PaliGemma由Transformer解码器和Vision Transformer图像编码器组成,总计30亿参数。文本解码器初始化自Gemma-2B,图像编码器初始化自SigLIP-So400m/14。训练遵循PaLI-3方案。
输入输出
- 输入: 图像及文本字符串(如图像描述提示或问题)
- 输出: 生成的响应文本(如图像描述、问题答案、目标框坐标列表或分割代码)
模型数据
预训练数据集
PaliGemma在以下混合数据集上预训练:
数据责任过滤
对WebLI应用以下过滤以确保数据清洁:
使用方式
PaliGemma是单轮视觉语言模型,不适合对话场景,在针对特定用例微调时表现最佳。
可通过任务前缀(如"detect"或"segment")配置模型任务。预训练模型通过此类前缀结构获得丰富能力(问答、描述、分割等),但设计初衷是通过类似提示结构微调到具体任务。交互测试可使用"mix"系列模型(已在多任务混合上微调)。
详见使用限制章节或访问博客文章获取更多示例。
Transformers应用
以下代码示例使用google/paligemma-3b-mix-224
模型。当前仓库中的模型可能针对其他任务训练,请确保输入与任务匹配。
CPU上运行默认精度(float32
)
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt")
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
输出: Un auto azul estacionado frente a un edificio.
CUDA上运行其他精度
为方便起见,仓库提供已转换为bfloat16
和float16
的权重版本,可减少下载体积并避免本地转换。
以下示例在NVIDIA CUDA显卡上运行bfloat16
:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
4-bit/8-bit加载
需安装bitsandbytes
以支持8-bit或4-bit精度推理:
pip install bitsandbytes accelerate
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, quantization_config=quantization_config
).eval()
processor = AutoProcessor.from_pretrained(model_id)
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
实现信息
硬件
PaliGemma使用最新一代TPU硬件(TPUv5e)训练。
软件
训练使用JAX、Flax、TFDS和big_vision
完成。
JAX支持利用TPU等最新硬件加速大模型训练。TFDS用于访问数据集,Flax用于模型架构。PaliGemma微调代码与推理代码发布于big_vision
仓库。
评估信息
基准测试结果
为验证PaliGemma在学术任务上的迁移能力,我们在各任务上微调预训练模型,并训练混合任务的mix模型。报告不同分辨率下的结果以展示任务提升情况。注意这些任务/数据集均未出现在预训练数据中。
混合模型(多任务微调)
基准 |
指标(拆分) |
mix-224 |
mix-448 |
MMVP |
配对准确率 |
46.00 |
45.33 |
POPE |
准确率 (随机/流行/对抗) |
88.00
86.63
85.67
|
89.37
88.40
87.47
|
GQA |
准确率(测试) |
65.20 |
65.47 |
单任务微调
基准 (训练拆分) |
指标 (拆分) |
pt-224 |
pt-448 |
pt-896 |
伦理与安全
评估方法
我们通过结构化评估和内部红队测试评估模型内容政策合规性。红队测试由多