许可证: mit
数据集:
- ehovy/race
语言:
- en
评估指标:
- bleu
基础模型:
- google-t5/t5-base
管道标签: 文本到文本生成
库名称: transformers
标签:
- 干扰项生成
- 教育
- 选择题问题
基于T5-base的干扰项生成模型
本仓库包含一个专为干扰项生成任务微调的T5-base模型。该模型利用T5的文本到文本框架和自定义分隔符,通过输入问题、上下文和正确答案,生成三个具有迷惑性的选择题干扰选项。
模型概览
基于PyTorch Lightning框架实现,该模型对预训练的T5-base进行微调。输入序列采用特殊格式:问题、上下文和正确答案通过自定义分隔符组合,输出序列包含三个干扰项。这种方法特别适用于选择题自动生成场景。
数据处理
输入格式
每个输入样本为以下结构的字符串:
问题 {分隔符} 上下文 {分隔符} 正确答案
- 问题: 题干文本
- 上下文: 背景段落
- 正确答案: 正确选项
- 分隔符: 分词器中添加的特殊分隔标记
输出格式
目标输出序列格式为:
干扰项1 {分隔符} 干扰项2 {分隔符} 干扰项3
这种格式支持模型单次生成三个干扰项。
训练参数
- 框架: PyTorch Lightning
- 基础模型: T5-base
- 优化器: 带线性调度器的Adam(使用预热策略)
- 批大小: 32
- 训练轮次: 5
- 学习率: 2e-5
- 分词设置:
- 特殊标记: 自定义
分隔符
被添加到分词器,用于区分输入和输出序列的不同部分
评估指标
采用BLEU分数评估生成的干扰项质量,测试集结果如下:
干扰项 |
BLEU-1 |
BLEU-2 |
BLEU-3 |
BLEU-4 |
干扰项1 |
29.59 |
21.55 |
17.86 |
15.75 |
干扰项2 |
25.21 |
16.81 |
13.00 |
10.78 |
干扰项3 |
23.99 |
15.78 |
12.35 |
10.52 |
数据显示模型生成的干扰项与参考干扰项具有较高的n-gram重叠度。
使用示例
通过Hugging Face的Transformers库调用模型:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
模型名称 = "fares7elsadek/t5-base-distractor-generation"
分词器 = AutoTokenizer.from_pretrained(模型名称)
模型 = AutoModelForSeq2SeqLM.from_pretrained(模型名称)
分隔符 = "<sep>"
def 生成干扰项(问题, 上下文, 正确答案, 最大长度=64):
输入文本 = f"{问题} {分隔符} {上下文} {分隔符} {正确答案}"
输入 = 分词器([输入文本], return_tensors="pt", truncation=True, padding=True)
输出 = 模型.generate(
input_ids=输入["input_ids"],
attention_mask=输入["attention_mask"],
max_length=最大长度
)
解码结果 = 分词器.decode(输出[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
干扰项列表 = [项.strip() for 项 in 解码结果.split(分隔符)]
return 干扰项列表
问题 = "法国的首都是哪里?"
上下文 = "法国是西欧国家,以丰富的历史和文化遗产闻名。"
正确答案 = "巴黎"
print(生成干扰项(问题, 上下文, 正确答案))