语言:
- 韩文
标签:
- trocr
- 图像转文本
许可证: mit
评估指标:
- 词错误率(wer)
- 字符错误率(cer)
示例:
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/random_2.jpg
示例标题: 随机句子1
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/random_6.jpg
示例标题: 随机句子2
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/chatbot_3.jpg
示例标题: 聊天机器人1
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/chatbot_5.jpg
示例标题: 聊天机器人2
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/news_1.jpg
示例标题: 新闻1
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/news_3.jpg
示例标题: 新闻2
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/nsmc_1.jpg
示例标题: 电影评论1
- 图片链接: https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/nsmc_2.jpg
示例标题: 电影评论2
韩语TrOCR概念验证模型
概述
由于TrOCR尚未发布包含韩语的多语言模型,我们为概念验证目的训练了韩语模型。建议基于该模型收集更多数据,进行第一阶段的补充训练或作为第二阶段进行微调。
数据收集
文本数据
我们通过处理三类数据集创建训练数据:
- 新闻摘要数据集:https://huggingface.co/datasets/daekeun-ml/naver-news-summarization-ko
- Naver电影情感分类:https://github.com/e9t/nsmc
- 聊天机器人数据集:https://github.com/songys/Chatbot_data
为高效收集数据,使用句子分割库(Kiwi Python封装;https://github.com/bab2min/kiwipiepy)分隔每句话,最终收集了637,401条样本。
图像数据
图像数据采用TrOCR论文中介绍的TextRecognitionDataGenerator(https://github.com/Belval/TextRecognitionDataGenerator)生成。以下是生成图像的代码片段:
python3 ./trdg/run.py -i ocr_dataset_poc.txt -w 5 -t {num_cores} -f 64 -l ko -c {num_samples} -na 2 --output_dir {dataset_dir}
训练
基础模型
编码器模型使用facebook/deit-base-distilled-patch16-384
,解码器模型使用klue/roberta-base
。从microsoft/trocr-base-stage1
加载权重开始训练更为便捷。
参数
采用启发式参数,未进行单独超参数调优:
- 学习率 = 4e-5
- 训练轮次 = 25
- 启用FP16 = True
- 最大长度 = 64
使用方式
inference.py
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoTokenizer
import requests
from io import BytesIO
from PIL import Image
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("daekeun-ml/ko-trocr-base-nsmc-news-chatbot")
tokenizer = AutoTokenizer.from_pretrained("daekeun-ml/ko-trocr-base-nsmc-news-chatbot")
url = "https://raw.githubusercontent.com/aws-samples/sm-kornlp/main/trocr/sample_imgs/news_1.jpg"
response = requests.get(url)
img = Image.open(BytesIO(response.content))
pixel_values = processor(img, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values, max_length=64)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
数据收集和模型训练所需的所有代码已发布在作者Github:
- https://github.com/daekeun-ml/sm-kornlp-usecases/tree/main/trocr