许可协议:apache-2.0
支持语言:
- 英文
任务类型:图像转文本
推理参数:
最大长度:800
基于Nougat的LaTeX模型
基于Nougat的LaTeX模型是在facebook/nougat-base基础上,使用im2latex-100k数据集进行微调,以提升其从图像生成LaTeX代码的能力。
由于Nougat初始编码器的输入图像尺寸不适合数学公式图像片段,可能导致缩放伪影,从而降低LaTeX代码的生成质量。为解决这一问题,基于Nougat的LaTeX模型调整了输入分辨率,并采用自适应填充方法,确保实际场景中的公式图像片段在缩放后尽可能接近训练数据的分辨率。
评估
在从维基百科、arXiv和im2latex-100k收集的图像-公式对数据集上进行了评估,数据集由lukas-blecher整理。
模型 |
标记准确率 ↑ |
归一化编辑距离 ↓ |
pix2tex |
0.5346 |
0.10312 |
pix2tex* |
0.60 |
0.10 |
nougat-latex-based |
0.623850 |
0.06180 |
pix2tex是一种基于ResNet + ViT + 文本解码器架构的模型,由LaTeX-OCR提出。
pix2tex*:数据来自LaTeX-OCR;
pix2tex:使用发布的检查点进行评估;
nougat-latex-based:使用束搜索策略生成结果进行评估。
运行要求
pip install transformers >= 4.34.0
使用方法
推理API小部件有时会截断响应。详情请参阅此问题。如果推理API因bug截断结果,建议自行运行模型。
- 下载仓库
git clone git@github.com:NormXU/nougat-latex-ocr.git
cd ./nougat-latex-ocr
- 推理
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast
from nougat_latex import NougatLaTexProcessor
model_name = "Norm/nougat-latex-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
tokenizer = NougatTokenizerFast.from_pretrained(model_name)
latex_processor = NougatLaTexProcessor.from_pretrained(model_name)
image = Image.open("path/to/latex/image.png")
if not image.mode == "RGB":
image = image.convert('RGB')
pixel_values = latex_processor(image, return_tensors="pt").pixel_values
decoder_input_ids = tokenizer(tokenizer.bos_token, add_special_tokens=False,
return_tensors="pt").input_ids
with torch.no_grad():
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_length,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
num_beams=5,
bad_words_ids=[[tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
sequence = tokenizer.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
print(sequence)