模型简介
模型特点
模型能力
使用案例
库名称:transformers
许可证:gemma
流水线标签:图像文本到文本
额外授权标题:在Hugging Face上访问PaliGemma
额外授权提示:要访问Hugging Face上的PaliGemma,您需要审阅并同意Google的使用许可。请确保您已登录Hugging Face并点击下方按钮。请求将立即处理。
额外授权按钮内容:确认许可
PaliGemma模型卡
模型页面: PaliGemma
Transformers PaliGemma 3B权重,基于896*896输入图像在DocVQA数据集上微调。模型提供float32、bfloat16和float16格式,仅供研究使用。微调配置详见big_vision。
资源与技术文档:
使用条款: 条款
作者: Google
模型信息
模型概述
描述
PaliGemma是一款多功能轻量级视觉语言模型(VLM),灵感来自PaLI-3,基于SigLIP视觉模型和Gemma语言模型等开放组件。它接收图像和文本输入并生成文本输出,支持多语言。专为在图像/短视频描述、视觉问答、文本阅读、目标检测与分割等广泛视觉语言任务上实现领先的微调性能而设计。
架构
PaliGemma由Transformer解码器和Vision Transformer图像编码器组成,总计30亿参数。文本解码器初始化自Gemma-2B,图像编码器初始化自SigLIP-So400m/14,训练遵循PaLI-3方案。
输入输出
- 输入: 图像和文本字符串(如图像描述提示或问题)
- 输出: 生成的响应文本(如图像描述、问题答案、目标坐标列表或分割代码)
模型数据
预训练数据集
PaliGemma预训练数据混合包括:
- WebLI: 从公开网络构建的多语言图文数据集,用于获取视觉语义理解、目标定位、多语言等能力
- CC3M-35L: 网页英文图像-替代文本对(Sharma等, 2018),通过Google翻译API扩展至34种语言
- VQ²A-CC3M-35L/VQG-CC3M-35L: VQ2A-CC3M子集,翻译为与CC3M-35L相同的34种语言
- OpenImages: 基于OpenImages数据集的手工规则生成的目标检测问答(Piergiovanni等, 2022)
- WIT: 维基百科收集的图文数据(Srinivasan等, 2021)
数据责任过滤
对WebLI应用以下过滤确保数据清洁:
- 色情图像过滤
- 文本安全过滤(移除涉及CSAI、色情、低俗或攻击性内容)
- 文本毒性过滤(使用Perspective API识别侮辱、淫秽、仇恨等内容)
- 文本个人信息过滤(通过Cloud DLP API移除社保号等敏感信息)
- 其他方法:根据政策的内容质量与安全过滤
使用方式
PaliGemma是单轮视觉语言模型,不适合对话场景,在针对特定用例微调时表现最佳。
可通过任务前缀(如"detect"或"segment")配置模型任务。预训练模型通过此类提示结构获得丰富能力(问答、描述、分割等),但设计初衷是通过类似提示结构微调到具体任务。交互测试可使用"mix"系列模型(已在多任务混合上微调)。
Transformers中使用
以下代码示例使用google/paligemma-3b-mix-224
模型。当前仓库中的模型可能针对其他任务训练,请确保输入与任务匹配。
CPU上运行默认精度(float32)
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
# 指示模型生成西班牙语描述
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt")
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
输出: Un auto azul estacionado frente a un edificio.
CUDA上运行其他精度
仓库已提供转换为bfloat16
和float16
的权重版本,可减少下载大小并避免本地转换。
以下示例在NVIDIA CUDA显卡上运行bfloat16
:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
4-bit/8-bit加载
需安装bitsandbytes
实现低精度推理:
pip install bitsandbytes accelerate
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, quantization_config=quantization_config
).eval()
processor = AutoProcessor.from_pretrained(model_id)
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
实现信息
硬件
使用最新一代TPU硬件(TPUv5e)训练。
软件
使用JAX、Flax、TFDS和big_vision
训练。微调与推理代码发布于big_vision
仓库。
评估信息
基准结果
为验证PaliGemma在学术任务上的可迁移性,我们在各任务上微调预训练模型,并报告不同分辨率下的结果(所有任务数据均未出现在预训练数据中)。
混合模型(多任务微调)
基准 | 指标(划分) | mix-224 | mix-448 |
---|---|---|---|
MMVP | 配对准确率 | 46.00 | 45.33 |
POPE | 准确率 (随机/流行/对抗) |
88.00/86.63/85.67 | 89.37/88.40/87.47 |
GQA | 准确率(测试) | 65.20 | 65.47 |
(完整基准结果表格因篇幅限制未完整展示,包含图像描述、问答、分割、视频任务等数十项指标)
伦理与安全
评估方法
通过结构化评估和内部红队测试评估内容政策合规性,涵盖儿童安全、内容安全和表征危害等类别,同时使用FairFace数据集等基准进行图像-文本评估。
评估结果
- 伦理安全评估结果符合内部政策阈值
- 使用Perspective API(阈值0.8)测量FairFace图像生成描述的毒性,报告不同人口属性组的最大值/中位数: