许可证:apache-2.0
语言:
- 高棉语
指标:
- 准确率
基础模型:
- google/mt5-small
任务标签:文本摘要
库名称:transformers
高棉语mT5摘要模型(1024词元)- V2版
简介
本仓库包含高棉语mT5摘要模型的改进版本songhieng/khmer-mt5-summarization-1024tk-V2。该版本在扩展数据集上进行了训练,包含来自kimleang123/rfi_news的数据,从而提升了高棉语文本的摘要性能。
模型详情
- 基础模型:
google/mt5-small
- 微调目标:支持长文本输入的高棉语摘要
- 训练数据集:
kimleang123/rfi_news
+ 原有数据集
- 框架:Hugging Face
transformers
- 任务类型:序列到序列(Seq2Seq)
- 输入:最长1024词元的高棉语文本(文章、段落或文档)
- 输出:高棉语摘要文本
- 训练硬件:GPU(Tesla T4)
- 评估指标:ROUGE分数
安装与设置
1️⃣ 安装依赖
确保已安装transformers
、torch
和datasets
:
pip install transformers torch datasets
2️⃣ 加载模型
加载并使用微调后的模型:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model_name = "songhieng/khmer-mt5-summarization-1024tk-V2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
使用方法
1️⃣ 使用Python代码
def summarize_khmer(text, max_length=150):
input_text = f"summarize: {text}"
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
summary_ids = model.generate(**inputs, max_length=max_length, num_beams=5, length_penalty=2.0, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
khmer_text = "កម្ពុជាមានប្រជាជនប្រមាណ ១៦ លាននាក់ ហើយវាគឺជាប្រទេសនៅតំបន់អាស៊ីអាគ្នេយ៍។"
summary = summarize_khmer(khmer_text)
print("高棉语摘要:", summary)
2️⃣ 使用Hugging Face流水线
from transformers import pipeline
summarizer = pipeline("summarization", model="songhieng/khmer-mt5-summarization-1024tk-V2")
khmer_text = "កម្ពុជាមានប្រជាជនប្រមាណ ១៦ លាននាក់ ហើយវាគឺជាប្រទេសនៅតំបន់អាស៊ីអាគ្នេយ៍។"
summary = summarizer(khmer_text, max_length=150, min_length=30, do_sample=False)
print("高棉语摘要:", summary[0]['summary_text'])
3️⃣ 通过FastAPI部署为API
from fastapi import FastAPI
app = FastAPI()
@app.post("/summarize/")
def summarize(text: str):
inputs = tokenizer(f"summarize: {text}", return_tensors="pt", truncation=True, max_length=1024)
summary_ids = model.generate(**inputs, max_length=150, num_beams=5, length_penalty=2.0, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return {"summary": summary}
模型评估
使用ROUGE分数评估模型,该指标衡量生成摘要与参考摘要的相似度。
from datasets import load_metric
rouge = load_metric("rouge")
def compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
return rouge.compute(predictions=decoded_preds, references=decoded_labels)
trainer.evaluate()
保存与上传模型
微调完成后,可将模型上传至Hugging Face Hub:
model.push_to_hub("songhieng/khmer-mt5-summarization-1024tk-V2")
tokenizer.push_to_hub("songhieng/khmer-mt5-summarization-1024tk-V2")
后续下载方式:
model = AutoModelForSeq2SeqLM.from_pretrained("songhieng/khmer-mt5-summarization-1024tk-V2")
tokenizer = AutoTokenizer.from_pretrained("songhieng/khmer-mt5-summarization-1024tk-V2")
摘要
特性 |
详情 |
基础模型 |
google/mt5-small |
任务 |
文本摘要 |
语言 |
高棉语(ខ្មែរ) |
数据集 |
kimleang123/rfi_news + 原有数据集 |
框架 |
Hugging Face Transformers |
评估指标 |
ROUGE分数 |
部署方式 |
Hugging Face模型库、API(FastAPI)、Python代码 |
贡献
欢迎贡献!如有改进建议,请提交问题或拉取请求。
联系方式
如有疑问,可通过Hugging Face讨论区或仓库问题板块联系。
为高棉语NLP社区构建