库名称:transformers
许可证:gemma
流水线标签:图像-文本到文本
额外授权标题:在Hugging Face上访问PaliGemma
额外授权提示:要在Hugging Face上访问PaliGemma,您需要审阅并同意Google的使用许可。为此,请确保您已登录Hugging Face并点击下方按钮。请求将立即处理。
额外授权按钮内容:确认许可
PaliGemma 2 模型卡
模型页面: PaliGemma
Transformers PaliGemma 2 3B权重,预训练时使用448*448输入图像和512个标记的输入/输出文本序列。该模型以bfloat16
格式提供,用于微调。
资源与技术文档:
使用条款: 条款
作者: Google
模型信息
模型概述
PaliGemma 2是对PaliGemma视觉语言模型(VLM)的更新,融合了Gemma 2模型的能力。PaliGemma系列模型受PaLI-3启发,基于SigLIP视觉模型和Gemma 2语言模型等开放组件构建。它接受图像和文本作为输入并生成文本输出,支持多种语言。该模型设计用于在广泛的视觉语言任务(如图像和短视频字幕生成、视觉问答、文本阅读、目标检测和分割)上实现领先的微调性能。
模型架构
PaliGemma 2由Transformer解码器和Vision Transformer图像编码器组成。文本解码器初始化自Gemma 2的2B、9B和27B参数规模。图像编码器初始化自SigLIP-So400m/14。与原始PaliGemma模型类似,PaliGemma 2的训练遵循PaLI-3的配方。
输入与输出
- 输入: 图像和文本字符串,如图像描述提示或问题。
- 输出: 根据输入生成的文本响应,如图像描述、问题答案、目标边界框坐标列表或分割编码词。
引用
@article{
title={PaliGemma 2: A Family of Versatile VLMs for Transfer},
author={Andreas Steiner and André Susano Pinto and Michael Tschannen and Daniel Keysers and Xiao Wang and Yonatan Bitton and Alexey Gritsenko and Matthias Minderer and Anthony Sherbondy and Shangbang Long and Siyang Qin and Reeve Ingle and Emanuele Bugliarello and Sahar Kazemzadeh and Thomas Mesnard and Ibrahim Alabdulmohsin and Lucas Beyer and Xiaohua Zhai},
year={2024},
journal={arXiv preprint arXiv:2412.03555}
}
模型数据
预训练数据集
PaliGemma 2在以下混合数据集上进行预训练:
PaliGemma 2基于Gemma 2,您可以在Gemma 2模型卡中找到Gemma 2预训练数据集的信息。
数据责任过滤
对WebLI应用以下过滤器,旨在训练PaliGemma 2使用安全和负责任的数据:
- 色情图像过滤: 此过滤器移除被视为色情性质的图像。
- 文本安全过滤: 我们识别并过滤掉与不安全文本配对的图像。不安全文本是被认为包含或涉及儿童性虐待图像(CSAI)、色情内容、粗俗内容或其他冒犯性内容的文本。
- 文本毒性过滤: 我们进一步使用Perspective API识别并过滤掉与被视为侮辱、淫秽、仇恨或其他有毒文本配对的图像。
- 文本个人信息过滤: 我们使用Cloud Data Loss Prevention (DLP) API过滤某些个人信息和其他敏感数据,以保护个人隐私。移除了社会安全号码和其他敏感信息类型等标识符。
- 其他方法: 根据我们的政策和实践,基于内容质量和安全性进行过滤。
在Transformers中使用
以下代码片段使用google/paligemma2-3b-pt-448
模型作为参考。这是一个基础模型,建议在下游任务上进行微调后使用。
这里有一个笔记本展示了如何微调PaliGemma 2。
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch
model_id = "google/paligemma2-3b-pt-448"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
实现信息
硬件
PaliGemma 2使用最新一代的Tensor Processing Unit(TPU)硬件(TPUv5e)进行训练。
软件
训练使用JAX、Flax、TFDS和big_vision
完成。
JAX允许研究人员利用包括TPU在内的最新硬件,更快更高效地训练大型模型。
TFDS用于访问数据集,Flax用于模型架构。PaliGemma 2的微调代码和推理代码在big_vision
GitHub仓库中发布。
评估信息
基准测试结果
为了验证PaliGemma 2在多种学术任务上的可迁移性,我们在每个任务上对预训练模型进行微调。我们报告不同分辨率的结果,以展示哪些任务受益于提高的分辨率。重要的是,这些任务或数据集都不是预训练数据的一部分,并且它们的图像已明确从网络规模的预训练数据中移除。
PaliGemma 2按模型分辨率和大小的结果
基准测试 |
224-3B |
224-10B |
224-28B |
448-3B |
448-10B |
448-28B |
[AI2D][ai2d] |
74.7 |
83.1 |
83.2 |
76.0 |
84.4 |
84.6 |
[AOKVQA-DA][aokvqa-da] (val) |
64.2 |
68.9 |
70.2 |
67.9 |
70.8 |
71.2 |
[AOKVQA-MC][aokvqa-mc] (val) |
79.7 |
83.7 |
84.7 |
82.5 |
85.9 |
87.0 |
[ActivityNet-CAP][anet-cap] |
34.2 |
35.9 |
- |
- |
- |
- |
[ActivityNet-QA][anet-qa] |
51.3 |
53.2 |
- |
- |
- |
- |
[COCO-35L][coco-35l] (avg34) |
113.9 |
115.8 |
116.5 |
115.8 |
117.2 |
117.2 |
[COCO-35L][coco-35l] (en) |
138.4 |
140.8 |
142.4 |
140.4 |
142.4 |
142.3 |
[COCOcap][coco-cap] |
141.3 |
143.7 |
144.0 |
143.4 |
145.0 |
145.2 |
[ChartQA][chartqa] (aug) |
74.4 |
74.2 |
68.9 |
89.2 |
90.1 |
85.1 |
[ChartQA][chartqa] (human) |
42.0 |
48.4 |
46.8 |
54.0 |
66.4 |
61.3 |
[CountBenchQA][countbenchqa] |
81.0 |
84.0 |
86.4 |
82.0 |
85.3 |
87.4 |
[DocVQA][docvqa] (val) |
39.9 |
43.9 |
44.9 |
73.6 |
76.6 |
76.1 |
[GQA][gqa] |
66.2 |
67.2 |
67.3 |
68.1 |
68.3 |
68.3 |
[InfoVQA][info-vqa] (val) |
25.2 |
33.6 |
36.4 |
37.5 |
47.8 |
46.7 |
[MARVL][marvl] (avg5) |
83.5 |
89.5 |
90.6 |
82.7 |
89.1 |
89.7 |
[MSRVTT-CAP][msrvtt] |
68.5 |
72.1 |
- |
- |
- |
- |
[MSRVTT-QA][msrvtt] |
50.5 |
51.9 |
- |
- |
- |
- |
[MSVD-QA][msvd-qa] |
61.1 |
62.5 |
|
|
|
|