🚀 用于干扰项生成的T5-large模型
本仓库包含一个针对干扰项生成任务进行微调的T5-large模型。该模型借助T5的文本到文本框架以及自定义分隔符标记,通过给定的问题、上下文和正确答案,为多项选择题生成三个合理的干扰项。
🚀 快速开始
你可以使用Hugging Face的Transformers管道按以下方式使用此模型:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-large-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question}{SEP_TOKEN}{correct}{SEP_TOKEN}{context}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
✨ 主要特性
- 基于T5的文本到文本框架,能够根据给定的问题、上下文和正确答案,为多项选择题生成三个合理的干扰项。
- 采用自定义分隔符标记,有效处理输入和目标序列。
📦 安装指南
暂未提及安装相关内容,可参考Hugging Face的Transformers库安装说明进行安装。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-large-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question}{SEP_TOKEN}{correct}{SEP_TOKEN}{context}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
📚 详细文档
模型概述
此实现基于PyTorch Lightning构建,对预训练的T5-base模型进行微调,以生成干扰项选项。模型接受一个经过格式化的单一输入序列,该序列包含问题、上下文和正确答案,并通过自定义标记分隔,然后生成包含三个干扰项的目标序列。这种方法在多项选择题生成任务中尤为有用。
数据处理
输入构建
每个输入样本是一个具有以下格式的单一字符串:
question {SEP_TOKEN} correct {SEP_TOKEN} context
- question:问题文本。
- context:上下文段落。
- correct:正确答案。
- SEP_TOKEN:添加到分词器中的特殊标记,用于分隔不同字段。
目标构建
每个目标样本的构建方式如下:
incorrect1 {SEP_TOKEN} incorrect2 {SEP_TOKEN} incorrect3
这种格式允许模型一次性生成三个干扰项。
训练详情
属性 |
详情 |
框架 |
PyTorch Lightning |
基础模型 |
T5-base |
优化器 |
采用线性调度的Adam优化器(使用预热调度器) |
批量大小 |
32 |
训练轮数 |
5 |
学习率 |
2e-5 |
分词处理 |
输入:最大长度为512个标记;目标:最大长度为64个标记 |
特殊标记 |
自定义的SEP_TOKEN 被添加到分词器中,用于分隔输入和目标序列的不同部分 |
评估指标
模型使用每个生成干扰项的BLEU分数进行评估。以下是在测试集上获得的BLEU分数:
干扰项 |
BLEU-1 |
BLEU-2 |
BLEU-3 |
BLEU-4 |
干扰项1 |
32.29 |
23.85 |
19.86 |
17.53 |
干扰项2 |
26.70 |
17.76 |
14.01 |
11.77 |
干扰项3 |
23.63 |
14.89 |
11.29 |
9.41 |
这些分数表明,与参考干扰项相比,该模型能够生成具有较高n元语法重叠的干扰项。 |
|
|
|
|
🔧 技术细节
- 借助T5的文本到文本框架,结合自定义分隔符标记,实现高效的干扰项生成。
- 采用PyTorch Lightning框架进行模型训练,利用Adam优化器和线性调度策略。
- 对输入和目标序列进行合理的分词处理,确保模型能够有效学习和生成干扰项。
📄 许可证
本项目采用MIT许可证。