library_name: transformers
tags:
- 图像转文本
- 文本生成推理
license: gemma
datasets:
- ucsahin/pubtables-detection-1500-samples
pipeline_tag: 图像文本转文本
paligemma-3b-mix-448-ft-TableDetection
该模型是基于google/paligemma-3b-mix-448在ucsahin/pubtables-detection-1500-samples数据集上进行混合精度微调的版本。
在评估集上取得了以下结果:
模型详情
- 本模型是一个多模态语言模型,专为图像中表格检测任务微调。模型通过结合图像和文本输入来预测图像中表格的边界框坐标。
- 该模型主要用于自动化图像中的表格检测流程,可应用于文档处理、数据提取和图像分析等领域,这些场景中识别图像内的表格至关重要。
输入要求:
- 图像: 需要输入包含一个或多个表格的图像,支持JPEG或PNG等标准格式。
- 文本提示: 需提供明确指示任务的文本提示。请使用**"detect table"**作为文本提示语。
输出说明:
- 边界框: 模型以特殊标记
<loc[数值]>
输出归一化坐标值,每个检测结果包含四个坐标值(顺序为y_min, x_min, y_max, x_max)及对应标签。将数值除以1024后,y坐标乘以图像高度,x坐标乘以图像宽度,即可得到原始图像尺寸的相对坐标。
若检测成功,模型将输出类似"<loc[值]><loc[值]><loc[值]><loc[值]> table; <loc[值]><loc[值]><loc[值]><loc[值]> table"的文本(检测到的表格数量决定输出段数)。可通过以下脚本将文本输出转换为PASCAL VOC格式的边界框:
import re
def post_process(bbox_text, image_width, image_height):
loc_values_str = [bbox.strip() for bbox in bbox_text.split(";")]
converted_bboxes = []
for loc_value_str in loc_values_str:
loc_values = re.findall(r'<loc(\d+)>', loc_value_str)
loc_values = [int(x) for x in loc_values]
loc_values = loc_values[:4]
loc_values = [value/1024 for value in loc_values]
loc_values = [
int(loc_values[1]*image_width), int(loc_values[0]*image_height),
int(loc_values[3]*image_width), int(loc_values[2]*image_height),
]
converted_bboxes.append(loc_values)
return converted_bboxes
快速开始
在Transformers中按以下方式加载模型:
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"
device = "cuda:0"
dtype = torch.bfloat16
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
推理示例如下:
prompt = "detect table"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
generation = generation[0][input_len:]
bbox_text = processor.decode(generation, skip_special_tokens=True)
print(bbox_text)
注意: 也可使用bitsandbytes
加载4位或8位量化模型。但需注意模型可能生成需要后处理的输出(例如出现五个<loc[值]>
标记或非"table"标签)。提供的后处理脚本可处理前一种情况。
4位量化模型加载方式:
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, BitsAndBytesConfig
import torch
model_id = "ucsahin/paligemma-3b-mix-448-ft-TableDetection"
device = "cuda:0"
dtype = torch.bfloat16
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype
)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
quantization_config=bnb_config
)
processor = PaliGemmaProcessor.from_pretrained(model_id)
偏差、风险与限制
关于偏差、风险和限制,请参考google/paligemma-3b-mix-448。
训练超参数
训练过程中使用的超参数如下:
- 学习率:0.0001
- 训练批次大小:4
- 评估批次大小:4
- 随机种子:42
- 梯度累积步数:4
- 混合精度:bf16
- 总训练批次大小:16
- 优化器:Adam(beta1=0.9,beta2=0.999,epsilon=1e-08)
- 学习率调度器类型:线性
- 学习率预热步数:5
- 训练轮次:3
训练结果
训练损失 |
轮次 |
步数 |
验证损失 |
2.957 |
0.1775 |
15 |
2.1300 |
1.9656 |
0.3550 |
30 |
1.8421 |
1.6716 |
0.5325 |
45 |
1.6898 |
1.5514 |
0.7101 |
60 |
1.5803 |
1.5851 |
0.8876 |
75 |
1.5271 |
1.4134 |
1.0651 |
90 |
1.4771 |
1.3566 |
1.2426 |
105 |
1.4528 |
1.3093 |
1.4201 |
120 |
1.4227 |
1.2897 |
1.5976 |
135 |
1.4115 |
1.256 |
1.7751 |
150 |
1.4007 |
1.2666 |
1.9527 |
165 |
1.3678 |
1.2213 |
2.1302 |
180 |
1.3744 |
1.0999 |
2.3077 |
195 |
1.3633 |
1.1931 |
2.4852 |
210 |
1.3606 |
1.0722 |
2.6627 |
225 |
1.3619 |
1.1485 |
2.8402 |
240 |
1.3544 |
框架版本
- PEFT 0.11.1
- Transformers 4.42.0.dev0
- Pytorch 2.3.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1