语言:
- 英文
指标:
- 准确率
库名称: transformers
标签:
- donut
- kyc
模型描述
Donut是一个端到端(即自包含)的视觉文档理解(VDU)模型,用于通用文档图像理解。Donut的架构相当简洁,由基于Transformer的视觉编码器和文本解码器模块组成。该模型不依赖任何OCR功能相关模块,而是通过视觉编码器从给定文档图像中提取特征。随后的文本解码器将这些特征映射为子词标记序列,以构建所需的结构化格式(如JSON)。每个组件均基于Transformer,因此模型能够以端到端方式轻松训练。

使用范围和限制
本模型专为读取印度KYC证件内容而训练,可分类识别Aadhar卡、PAN卡和选民证信息,同时检测文档朝向及是否为彩色/黑白。输入文档可任意方向摆放,但需提供清晰可读的图像质量。由于训练数据量有限,当前性能可能未达最优,未来版本将增加训练样本并扩展更多KYC证件类型。
训练数据
v1版本使用自定义数据集进行训练,共283张图像,其中199张用于训练,42张用于验证,42张用于测试。训练集包含57张Aadhar样本、57张PAN样本和85张选民证样本。
性能表现
当前准确率:
整体准确率 = 74%
Aadhar = 49%(需排查低准确率原因)
PAN = 94%
选民证 = 76%
推理示例
from transformers import DonutProcessor, VisionEncoderDecoderModel
import re
import cv2
import json
import torch
from tqdm.auto import tqdm
import numpy as np
from donut import JSONParseEvaluator
processor = DonutProcessor.from_pretrained("sourinkarmakar/kyc_v1-donut-demo")
model = VisionEncoderDecoderModel.from_pretrained("sourinkarmakar/kyc_v1-donut-demo")
dataset = glob.glob(os.path.join(basepath, "unseen_samples/*"))
output_list = []
for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
img = cv2.imread(sample)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
pixel_values = processor(img, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device)
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)
outputs = model.generate(
pixel_values,
decoder_input_ids=decoder_input_ids,
max_length=model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
seq = processor.batch_decode(outputs.sequences)[0]
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
seq = re.sub(r"<.*?>", "", seq, count=1).strip()
seq = processor.token2json(seq)
output_list.append(seq)
print(output_list)