许可证: gpl-3.0
语言:
内容摘要
论文摘要指出:
图表在数据分析、可视化关键洞察以及回答关于数据的复杂推理问题时非常受欢迎。为了促进使用自然语言进行基于图表的数据分析,最近引入了几项下游任务,如图表问答和图表摘要。然而,解决这些任务的大多数方法都使用在语言或视觉-语言任务上的预训练,这些任务并未尝试明确建模图表的结构(例如,数据是如何被视觉编码的,以及图表元素之间是如何相互关联的)。为了解决这个问题,我们首先构建了一个包含各种主题和视觉风格的大型图表语料库。然后,我们提出了UniChart,一个用于图表理解和推理的预训练模型。UniChart编码了图表中的相关文本、数据和视觉元素,然后使用基于图表的文本解码器以自然语言生成预期的输出。我们提出了几个特定于图表的预训练任务,包括:(i)提取图表中视觉元素(如条形、线条)和数据的低级任务,以及(ii)获取图表理解和推理能力的高级任务。我们发现,在大型语料库上使用特定于图表的低级和高级任务进行预训练,然后在三个下游任务上进行微调,可以在三个下游任务上实现最先进的性能。
网络演示
如果您想快速尝试我们的模型,可以访问我们在Hugging Face Spaces平台上托管的公共网络演示,界面友好!
图表摘要的输入提示是 <summarize_chart>,数据表生成的输入提示是 <extract_data_table>。
推理
您可以轻松地使用huggingface库进行推理!只需按照以下步骤操作:
- 将 model_name 更改为您偏好的检查点。
- 将 imag_path 更改为您系统上图表示例图像的路径。
- 根据您偏好的任务编写 input_prompt,如下表所示。
任务 |
输入提示 |
图表问答 |
<chartqa> 问题 <s_answer> |
开放图表问答 |
<opencqa> 问题 <s_answer> |
图表摘要 |
<summarize_chart> <s_answer> |
数据表提取 |
<extract_data_table> <s_answer> |
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch, os, re
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_1.png')
model_name = "ahmed-masry/unichart-chartqa-960"
image_path = "/content/chart_example_1.png"
input_prompt = "<chartqa> What is the lowest value in blue bar? <s_answer>"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
image = Image.open(image_path).convert("RGB")
decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(image, return_tensors="pt").pixel_values
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=4,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
sequence = processor.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
sequence = sequence.split("<s_answer>")[1].strip()
print(sequence)
联系方式
如果您对此工作有任何疑问,请联系 Ahmed Masry,使用以下电子邮件地址:amasry17@ku.edu.tr 或 ahmed.elmasry24653@gmail.com。
参考文献
如果您在研究中使用了我们的模型或数据集,请引用我们的论文。
@misc{masry2023unichart,
title={UniChart: A Universal Vision-language Pretrained Model for Chart Comprehension and Reasoning},
author={Ahmed Masry and Parsa Kavehzadeh and Xuan Long Do and Enamul Hoque and Shafiq Joty},
year={2023},
eprint={2305.14761},
archivePrefix={arXiv},
primaryClass={cs.CL}
}